From 5f4604cc5767af635650f28e3765983f28f97d44 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Tue, 26 Aug 2025 21:38:49 -0400 Subject: [PATCH 01/13] add files --- .../datarobot_drum/drum/args_parser.py | 34 +++++++++++++++++ .../datarobot_drum/drum/description.py | 2 +- .../datarobot_drum/drum/drum.py | 14 +++++++ .../datarobot_drum/drum/main.py | 11 ++++++ .../drum/root_predictors/prediction_server.py | 37 +++++++++++++++++-- custom_model_runner/requirements.txt | 2 + 6 files changed, 96 insertions(+), 4 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/args_parser.py b/custom_model_runner/datarobot_drum/drum/args_parser.py index 8860e4218..864869feb 100644 --- a/custom_model_runner/datarobot_drum/drum/args_parser.py +++ b/custom_model_runner/datarobot_drum/drum/args_parser.py @@ -122,6 +122,16 @@ def _reg_args_lazy_loading_file(*parsers): help="Path to a lazy loading values file. (env: LAZY_LOADING_FILE)", ) + @staticmethod + def _reg_arg_server_type(*parsers): + for parser in parsers: + parser.add_argument( + '--server-type', + type=str, + default=None, + help='Type of server to run (optional, string)' + ) + @staticmethod def _reg_arg_output(*parsers): for parser in parsers: @@ -724,6 +734,27 @@ def _reg_arg_triton_server_access(*parsers): help="NVIDIA Triton Inference Server GRPC port", ) + @staticmethod + def _reg_arg_gunicorn_options(*parsers): + for parser in parsers: + parser.add_argument('--gunicorn-backlog', type=int, default=2024, help='Gunicorn backlog') + parser.add_argument('--gunicorn-timeout', type=int, default=120, help='Gunicorn timeout') + parser.add_argument('--gunicorn-graceful-timeout', type=int, default=30, help='Gunicorn graceful timeout') + parser.add_argument('--gunicorn-keep-alive', type=int, default=5, help='Gunicorn keep alive') + parser.add_argument('--gunicorn-max-requests', type=int, default=2000, help='Gunicorn max requests') + parser.add_argument('--gunicorn-max-requests-jitter', type=int, default=500, + help='Gunicorn max requests jitter') + parser.add_argument('--gunicorn-log-level', type=str, default='info', help='Gunicorn log level') + parser.add_argument('--gunicorn-access-logfile', type=str, default='-', help='Gunicorn access logfile') + parser.add_argument('--gunicorn-error-logfile', type=str, default='-', help='Gunicorn error logfile') + parser.add_argument('--gunicorn-access-logformat', type=str, + default='%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"', + help='Gunicorn access log format') + parser.add_argument('--gunicorn-workers', type=int, default=None, + help='Gunicorn number of workers (overrides --max-workers)') + parser.add_argument('--gunicorn-worker-class', type=str, default=None, + help='Gunicorn worker class (e.g. sync, gevent, eventlet, etc.)') + @staticmethod def _register_subcommand_perf_test(subparsers): desc = """ @@ -1026,6 +1057,9 @@ def get_arg_parser(): score_parser, server_parser, validation_parser ) + CMRunnerArgsRegistry._reg_arg_server_type(server_parser) + CMRunnerArgsRegistry._reg_arg_gunicorn_options(server_parser) + return parser @staticmethod diff --git a/custom_model_runner/datarobot_drum/drum/description.py b/custom_model_runner/datarobot_drum/drum/description.py index 83a754e6c..8a88317f6 100644 --- a/custom_model_runner/datarobot_drum/drum/description.py +++ b/custom_model_runner/datarobot_drum/drum/description.py @@ -4,6 +4,6 @@ This is proprietary source code of DataRobot, Inc. and its affiliates. Released under the terms of DataRobot Tool and Utility Agreement. 
""" -version = "1.16.24" +version = "1.16.25" __version__ = version project_name = "datarobot-drum" diff --git a/custom_model_runner/datarobot_drum/drum/drum.py b/custom_model_runner/datarobot_drum/drum/drum.py index 37df6da5e..efb3a6887 100644 --- a/custom_model_runner/datarobot_drum/drum/drum.py +++ b/custom_model_runner/datarobot_drum/drum/drum.py @@ -771,6 +771,20 @@ def _prepare_prediction_server_or_batch_pipeline(self, run_language): "target_type": self.target_type.value, "user_secrets_mount_path": getattr(options, "user_secrets_mount_path", None), "user_secrets_prefix": getattr(options, "user_secrets_prefix", None), + "server_type": getattr(options, "server_type", None), + # Gunicorn options + "gunicorn_backlog": getattr(options, "gunicorn_backlog", None), + "gunicorn_timeout": getattr(options, "gunicorn_timeout", None), + "gunicorn_graceful_timeout": getattr(options, "gunicorn_graceful_timeout", None), + "gunicorn_keep_alive": getattr(options, "gunicorn_keep_alive", None), + "gunicorn_max_requests": getattr(options, "gunicorn_max_requests", None), + "gunicorn_max_requests_jitter": getattr(options, "gunicorn_max_requests_jitter", None), + "gunicorn_log_level": getattr(options, "gunicorn_log_level", None), + "gunicorn_access_logfile": getattr(options, "gunicorn_access_logfile", None), + "gunicorn_error_logfile": getattr(options, "gunicorn_error_logfile", None), + "gunicorn_access_logformat": getattr(options, "gunicorn_access_logformat", None), + "gunicorn_workers": getattr(options, "gunicorn_workers", None), + "gunicorn_worker_class": getattr(options, "gunicorn_worker_class", None), } if self.run_mode == RunMode.SCORE: diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py index 94aa0e5b5..b5775818a 100644 --- a/custom_model_runner/datarobot_drum/drum/main.py +++ b/custom_model_runner/datarobot_drum/drum/main.py @@ -43,6 +43,17 @@ import signal import sys +# Monkey patching for gevent compatibility if running with gunicorn-gevent +if ( + "gunicorn-gevent" in sys.argv or os.environ.get("SERVER_TYPE") == "gunicorn-gevent" +): + try: + from gevent import monkey + + monkey.patch_all() + except ImportError: + pass + from datarobot_drum.drum.common import config_logging, setup_otel from datarobot_drum.drum.utils.setup import setup_options from datarobot_drum.drum.enum import RunMode diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index d2a6b3d5e..841efb70d 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -295,15 +295,46 @@ def handle_exception(e): def _run_flask_app(self, app): host = self._params.get("host", None) port = self._params.get("port", None) - + server_type = self._params.get("server_type", "flask") processes = 1 if self._params.get("processes"): processes = self._params.get("processes") logger.info("Number of webserver processes: %s", processes) try: - app.run(host, port, threaded=False, processes=processes) + if True: + try: + from gunicorn.app.base import BaseApplication + except ImportError: + BaseApplication = None + raise DrumCommonException("gunicorn is not installed. 
Please install gunicorn.") + class GunicornApp(BaseApplication): + def __init__(self, app, host, port, params): + self.application = app + self.host = host + self.port = port + self.params = params + super().__init__() + def load_config(self): + self.cfg.set("bind", f"{self.host}:{self.port}") + workers = self.params.get("gunicorn_workers") or self.params.get("max_workers") or processes + self.cfg.set("workers", workers) + self.cfg.set("worker_class", self.params.get("gunicorn_worker_class", "sync")) + self.cfg.set("backlog", self.params.get("gunicorn_backlog", 2048)) + self.cfg.set("timeout", self.params.get("gunicorn_timeout", 120)) + self.cfg.set("graceful_timeout", self.params.get("gunicorn_graceful_timeout", 30)) + self.cfg.set("keepalive", self.params.get("gunicorn_keep_alive", 5)) + self.cfg.set("max_requests", self.params.get("gunicorn_max_requests", 2000)) + self.cfg.set("max_requests_jitter", self.params.get("gunicorn_max_requests_jitter", 500)) + self.cfg.set("loglevel", self.params.get("gunicorn_log_level", "info")) + # Remove unsupported config keys: access_logfile, error_logfile, access_logformat + # These must be set via CLI, not config API + def load(self): + return self.application + GunicornApp(app, host, port, self._params).run() + else: + app.run(host, port, threaded=False, processes=processes) except OSError as e: - raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port)) + raise DrumCommonException(f"{e}: host: {host}; port: {port}") def terminate(self): terminate_op = getattr(self._predictor, "terminate", None) diff --git a/custom_model_runner/requirements.txt b/custom_model_runner/requirements.txt index db2466b50..d73b3b8b4 100644 --- a/custom_model_runner/requirements.txt +++ b/custom_model_runner/requirements.txt @@ -3,6 +3,8 @@ argcomplete trafaret>=2.0.0 docker>=4.2.2 flask +gevent>=22.10.2 +gunicorn>=20.1.0 jinja2>=3.0.0 memory_profiler<1.0.0 numpy From ed4404f9c55dfc94030263dd9a99835370c59f69 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 12:35:53 -0400 Subject: [PATCH 02/13] fixed logs --- .../datarobot_drum/drum/root_predictors/prediction_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index 841efb70d..a1f4e37b7 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -326,6 +326,9 @@ def load_config(self): self.cfg.set("max_requests", self.params.get("gunicorn_max_requests", 2000)) self.cfg.set("max_requests_jitter", self.params.get("gunicorn_max_requests_jitter", 500)) self.cfg.set("loglevel", self.params.get("gunicorn_log_level", "info")) + self.cfg.set('accesslog', '-') + self.cfg.set('errorlog', '-') # if you want error logs to stdout + self.cfg.set('access_log_format', '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"') # Remove unsupported config keys: access_logfile, error_logfile, access_logformat # These must be set via CLI, not config API def load(self): From 1b6ff68852a007530785ab2f3fba8803783b45b5 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 13:59:05 -0400 Subject: [PATCH 03/13] fix run --- .../drum/root_predictors/prediction_server.py | 2 +- .../pipelines/prediction_server_pipeline.json.j2 | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git 
a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index a1f4e37b7..d9adebbc8 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -328,7 +328,7 @@ def load_config(self): self.cfg.set("loglevel", self.params.get("gunicorn_log_level", "info")) self.cfg.set('accesslog', '-') self.cfg.set('errorlog', '-') # if you want error logs to stdout - self.cfg.set('access_log_format', '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"') + self.cfg.set('access_log_format', '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"') # Remove unsupported config keys: access_logfile, error_logfile, access_logformat # These must be set via CLI, not config API def load(self): diff --git a/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 b/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 index fbd753997..6eee2ca09 100644 --- a/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 +++ b/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 @@ -20,6 +20,7 @@ "processes": {{ processes }}, "monitor": "{{ monitor }}", "monitor_embedded": "{{ monitor_embedded }}", + "server_type": "{{ server_type }}", "model_id": "{{ model_id }}", "deployment_id": {{ deployment_id | jsonify }}, "monitor_settings": {{ monitor_settings | jsonify }}, @@ -33,7 +34,16 @@ "sidecar": {{ sidecar | jsonify }}, "triton_host": {{ triton_host | jsonify }}, "triton_http_port": {{ triton_http_port | jsonify }}, - "triton_grpc_port": {{ triton_grpc_port | jsonify }} + "triton_grpc_port": {{ triton_grpc_port | jsonify }}, + "gunicorn_backlog": {{ gunicorn_backlog }}, + "gunicorn_timeout": {{ gunicorn_timeout }}, + "gunicorn_graceful_timeout": {{ gunicorn_graceful_timeout }}, + "gunicorn_keep_alive": {{ gunicorn_keep_alive }}, + "gunicorn_max_requests": {{ gunicorn_max_requests }}, + "gunicorn_max_requests_jitter": {{ gunicorn_max_requests_jitter }}, + "gunicorn_log_level": "{{ gunicorn_log_level }}", + "gunicorn_workers": {{ gunicorn_workers }}, + "gunicorn_worker_class": "{{ gunicorn_worker_class }}" } } ] From 850bd9c4b923659dc51c9ce1f999ae76cf34d969 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 15:18:19 -0400 Subject: [PATCH 04/13] add runtime for flask --- .../drum/root_predictors/prediction_server.py | 98 ++++++++++++++++--- 1 file changed, 85 insertions(+), 13 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index d9adebbc8..6ba0514e0 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -13,6 +13,8 @@ from werkzeug.exceptions import HTTPException from opentelemetry import trace + +from datarobot_drum import RuntimeParameters from datarobot_drum.drum.description import version as drum_version from datarobot_drum.drum.enum import ( FLASK_EXT_FILE_NAME, @@ -292,40 +294,108 @@ def handle_exception(e): return [] + def get_gunicorn_config(self): + config = {} + if RuntimeParameters.has("DRUM_GUNICORN_WORKER_CLASS"): + worker_class = 
str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")) + if worker_class.lower() in {"sync", "gevent"}: + config["worker_class"] = worker_class + + if RuntimeParameters.has("DRUM_GUNICORN_WORKER_CONNECTIONS"): + worker_connections = int(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CONNECTIONS")) + if 1 <= worker_connections <= 10000: + config["worker_connections"] = worker_connections + + if RuntimeParameters.has("DRUM_GUNICORN_BACKLOG"): + backlog = int(RuntimeParameters.get("DRUM_GUNICORN_BACKLOG")) + if 1 <= backlog <= 2048: + config["backlog"] = backlog + + if RuntimeParameters.has("DRUM_GUNICORN_TIMEOUT"): + timeout = int(RuntimeParameters.get("DRUM_GUNICORN_TIMEOUT")) + if 1 <= timeout <= 3600: + config["timeout"] = timeout + + if RuntimeParameters.has("DRUM_GUNICORN_GRACEFUL_TIMEOUT"): + graceful_timeout = int(RuntimeParameters.get("DRUM_GUNICORN_GRACEFUL_TIMEOUT")) + if 1 <= graceful_timeout <= 3600: + config["graceful_timeout"] = graceful_timeout + + if RuntimeParameters.has("DRUM_GUNICORN_KEEP_ALIVE"): + keepalive = int(RuntimeParameters.get("DRUM_GUNICORN_KEEP_ALIVE")) + if 1 <= keepalive <= 3600: + config["keepalive"] = keepalive + + if RuntimeParameters.has("DRUM_GUNICORN_MAX_REQUESTS"): + max_requests = int(RuntimeParameters.get("DRUM_GUNICORN_MAX_REQUESTS")) + if 100 <= max_requests <= 10000: + config["max_requests"] = max_requests + + if RuntimeParameters.has("DRUM_GUNICORN_MAX_REQUESTS_JITTER"): + max_requests_jitter = int(RuntimeParameters.get("DRUM_GUNICORN_MAX_REQUESTS_JITTER")) + if 1 <= max_requests_jitter <= 10000: + config["max_requests_jitter"] = max_requests_jitter + + if RuntimeParameters.has("DRUM_GUNICORN_LOG_LEVEL"): + loglevel = str(RuntimeParameters.get("DRUM_GUNICORN_LOG_LEVEL")) + if loglevel.lower() in {"debug", "info", "warning", "error", "critical"}: + config["loglevel"] = loglevel + + return config + + def get_server_type(self): + server_type = "flask" + if RuntimeParameters.has("DRUM_SERVER_TYPE"): + server_type = str(RuntimeParameters.get("DRUM_SERVER_TYPE")) + if server_type.lower() in {"flask", "gunicorn"}: + server_type = server_type.lower() + return server_type + + def _run_flask_app(self, app): host = self._params.get("host", None) port = self._params.get("port", None) - server_type = self._params.get("server_type", "flask") + server_type = self.get_server_type() processes = 1 if self._params.get("processes"): processes = self._params.get("processes") logger.info("Number of webserver processes: %s", processes) try: - if True: + logger.info(self._params) + if server_type == "gunicorn": + logger.info("Starting gunicorn server") try: from gunicorn.app.base import BaseApplication except ImportError: BaseApplication = None raise DrumCommonException("gunicorn is not installed. 
Please install gunicorn.") class GunicornApp(BaseApplication): - def __init__(self, app, host, port, params): + def __init__(self, app, host, port, params, gunicorn_config): self.application = app self.host = host self.port = port self.params = params + self.gunicorn_config = gunicorn_config super().__init__() def load_config(self): + self.cfg.set("bind", f"{self.host}:{self.port}") - workers = self.params.get("gunicorn_workers") or self.params.get("max_workers") or processes + workers = self.params.get("gunicorn_workers") or self.params.get("max_workers") or self.params.get("processes") self.cfg.set("workers", workers) - self.cfg.set("worker_class", self.params.get("gunicorn_worker_class", "sync")) - self.cfg.set("backlog", self.params.get("gunicorn_backlog", 2048)) - self.cfg.set("timeout", self.params.get("gunicorn_timeout", 120)) - self.cfg.set("graceful_timeout", self.params.get("gunicorn_graceful_timeout", 30)) - self.cfg.set("keepalive", self.params.get("gunicorn_keep_alive", 5)) - self.cfg.set("max_requests", self.params.get("gunicorn_max_requests", 2000)) - self.cfg.set("max_requests_jitter", self.params.get("gunicorn_max_requests_jitter", 500)) - self.cfg.set("loglevel", self.params.get("gunicorn_log_level", "info")) + + self.cfg.set("worker_class", self.gunicorn_config.get("worker_class", "sync")) + self.cfg.set("backlog", self.gunicorn_config.get("backlog", 2048)) + self.cfg.set("timeout", self.gunicorn_config.get("timeout", 120)) + self.cfg.set("graceful_timeout", self.gunicorn_config.get("graceful_timeout", 30)) + self.cfg.set("keepalive", self.gunicorn_config.get("keepalive", 5)) + self.cfg.set("max_requests", self.gunicorn_config.get("max_requests", 2000)) + self.cfg.set("max_requests_jitter", self.gunicorn_config.get("max_requests_jitter", 500)) + + if self.gunicorn_config.get("worker_connections"): + self.cfg.set("worker_connections", self.gunicorn_config.get("worker_connections")) + self.cfg.set("loglevel", self.gunicorn_config.get("loglevel", "info")) + + self.cfg.set('accesslog', '-') self.cfg.set('errorlog', '-') # if you want error logs to stdout self.cfg.set('access_log_format', '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"') @@ -333,7 +403,9 @@ def load_config(self): # These must be set via CLI, not config API def load(self): return self.application - GunicornApp(app, host, port, self._params).run() + + gunicorn_config = self.get_gunicorn_config() + GunicornApp(app, host, port, self._params, gunicorn_config).run() else: app.run(host, port, threaded=False, processes=processes) except OSError as e: From 188db51ecc922e8e3b9c54eecc9ef463665a91b6 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 15:36:36 -0400 Subject: [PATCH 05/13] removed args --- .../datarobot_drum/drum/args_parser.py | 34 ------------------- .../datarobot_drum/drum/drum.py | 14 -------- .../prediction_server_pipeline.json.j2 | 12 +------ 3 files changed, 1 insertion(+), 59 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/args_parser.py b/custom_model_runner/datarobot_drum/drum/args_parser.py index 864869feb..8860e4218 100644 --- a/custom_model_runner/datarobot_drum/drum/args_parser.py +++ b/custom_model_runner/datarobot_drum/drum/args_parser.py @@ -122,16 +122,6 @@ def _reg_args_lazy_loading_file(*parsers): help="Path to a lazy loading values file. 
(env: LAZY_LOADING_FILE)", ) - @staticmethod - def _reg_arg_server_type(*parsers): - for parser in parsers: - parser.add_argument( - '--server-type', - type=str, - default=None, - help='Type of server to run (optional, string)' - ) - @staticmethod def _reg_arg_output(*parsers): for parser in parsers: @@ -734,27 +724,6 @@ def _reg_arg_triton_server_access(*parsers): help="NVIDIA Triton Inference Server GRPC port", ) - @staticmethod - def _reg_arg_gunicorn_options(*parsers): - for parser in parsers: - parser.add_argument('--gunicorn-backlog', type=int, default=2024, help='Gunicorn backlog') - parser.add_argument('--gunicorn-timeout', type=int, default=120, help='Gunicorn timeout') - parser.add_argument('--gunicorn-graceful-timeout', type=int, default=30, help='Gunicorn graceful timeout') - parser.add_argument('--gunicorn-keep-alive', type=int, default=5, help='Gunicorn keep alive') - parser.add_argument('--gunicorn-max-requests', type=int, default=2000, help='Gunicorn max requests') - parser.add_argument('--gunicorn-max-requests-jitter', type=int, default=500, - help='Gunicorn max requests jitter') - parser.add_argument('--gunicorn-log-level', type=str, default='info', help='Gunicorn log level') - parser.add_argument('--gunicorn-access-logfile', type=str, default='-', help='Gunicorn access logfile') - parser.add_argument('--gunicorn-error-logfile', type=str, default='-', help='Gunicorn error logfile') - parser.add_argument('--gunicorn-access-logformat', type=str, - default='%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"', - help='Gunicorn access log format') - parser.add_argument('--gunicorn-workers', type=int, default=None, - help='Gunicorn number of workers (overrides --max-workers)') - parser.add_argument('--gunicorn-worker-class', type=str, default=None, - help='Gunicorn worker class (e.g. 
sync, gevent, eventlet, etc.)') - @staticmethod def _register_subcommand_perf_test(subparsers): desc = """ @@ -1057,9 +1026,6 @@ def get_arg_parser(): score_parser, server_parser, validation_parser ) - CMRunnerArgsRegistry._reg_arg_server_type(server_parser) - CMRunnerArgsRegistry._reg_arg_gunicorn_options(server_parser) - return parser @staticmethod diff --git a/custom_model_runner/datarobot_drum/drum/drum.py b/custom_model_runner/datarobot_drum/drum/drum.py index efb3a6887..37df6da5e 100644 --- a/custom_model_runner/datarobot_drum/drum/drum.py +++ b/custom_model_runner/datarobot_drum/drum/drum.py @@ -771,20 +771,6 @@ def _prepare_prediction_server_or_batch_pipeline(self, run_language): "target_type": self.target_type.value, "user_secrets_mount_path": getattr(options, "user_secrets_mount_path", None), "user_secrets_prefix": getattr(options, "user_secrets_prefix", None), - "server_type": getattr(options, "server_type", None), - # Gunicorn options - "gunicorn_backlog": getattr(options, "gunicorn_backlog", None), - "gunicorn_timeout": getattr(options, "gunicorn_timeout", None), - "gunicorn_graceful_timeout": getattr(options, "gunicorn_graceful_timeout", None), - "gunicorn_keep_alive": getattr(options, "gunicorn_keep_alive", None), - "gunicorn_max_requests": getattr(options, "gunicorn_max_requests", None), - "gunicorn_max_requests_jitter": getattr(options, "gunicorn_max_requests_jitter", None), - "gunicorn_log_level": getattr(options, "gunicorn_log_level", None), - "gunicorn_access_logfile": getattr(options, "gunicorn_access_logfile", None), - "gunicorn_error_logfile": getattr(options, "gunicorn_error_logfile", None), - "gunicorn_access_logformat": getattr(options, "gunicorn_access_logformat", None), - "gunicorn_workers": getattr(options, "gunicorn_workers", None), - "gunicorn_worker_class": getattr(options, "gunicorn_worker_class", None), } if self.run_mode == RunMode.SCORE: diff --git a/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 b/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 index 6eee2ca09..fbd753997 100644 --- a/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 +++ b/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 @@ -20,7 +20,6 @@ "processes": {{ processes }}, "monitor": "{{ monitor }}", "monitor_embedded": "{{ monitor_embedded }}", - "server_type": "{{ server_type }}", "model_id": "{{ model_id }}", "deployment_id": {{ deployment_id | jsonify }}, "monitor_settings": {{ monitor_settings | jsonify }}, @@ -34,16 +33,7 @@ "sidecar": {{ sidecar | jsonify }}, "triton_host": {{ triton_host | jsonify }}, "triton_http_port": {{ triton_http_port | jsonify }}, - "triton_grpc_port": {{ triton_grpc_port | jsonify }}, - "gunicorn_backlog": {{ gunicorn_backlog }}, - "gunicorn_timeout": {{ gunicorn_timeout }}, - "gunicorn_graceful_timeout": {{ gunicorn_graceful_timeout }}, - "gunicorn_keep_alive": {{ gunicorn_keep_alive }}, - "gunicorn_max_requests": {{ gunicorn_max_requests }}, - "gunicorn_max_requests_jitter": {{ gunicorn_max_requests_jitter }}, - "gunicorn_log_level": "{{ gunicorn_log_level }}", - "gunicorn_workers": {{ gunicorn_workers }}, - "gunicorn_worker_class": "{{ gunicorn_worker_class }}" + "triton_grpc_port": {{ triton_grpc_port | jsonify }} } } ] From 3c872bda1121a4cfffd0160eefc8e461d59d3b87 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 15:36:48 -0400 Subject: [PATCH 06/13] fix main --- 
custom_model_runner/datarobot_drum/drum/main.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py index b5775818a..5a316054b 100644 --- a/custom_model_runner/datarobot_drum/drum/main.py +++ b/custom_model_runner/datarobot_drum/drum/main.py @@ -42,17 +42,16 @@ import os import signal import sys +from datarobot_drum import RuntimeParameters # Monkey patching for gevent compatibility if running with gunicorn-gevent -if ( - "gunicorn-gevent" in sys.argv or os.environ.get("SERVER_TYPE") == "gunicorn-gevent" -): - try: - from gevent import monkey - - monkey.patch_all() - except ImportError: - pass +if RuntimeParameters.has("DRUM_SERVER_TYPE") and RuntimeParameters.has("DRUM_GUNICORN_WORKER_CLASS"): + if str(RuntimeParameters.has("DRUM_SERVER_TYPE")).lower() == "gunicorn" and str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower() == "gevent": + try: + from gevent import monkey + monkey.patch_all() + except ImportError: + pass from datarobot_drum.drum.common import config_logging, setup_otel from datarobot_drum.drum.utils.setup import setup_options From 651d7fc93e923e4b891679c467ae40b0520fce7f Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 15:42:03 -0400 Subject: [PATCH 07/13] lint formatting --- .../datarobot_drum/drum/main.py | 10 ++++- .../drum/root_predictors/prediction_server.py | 41 +++++++++++++------ 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py index 5a316054b..167ffe867 100644 --- a/custom_model_runner/datarobot_drum/drum/main.py +++ b/custom_model_runner/datarobot_drum/drum/main.py @@ -45,10 +45,16 @@ from datarobot_drum import RuntimeParameters # Monkey patching for gevent compatibility if running with gunicorn-gevent -if RuntimeParameters.has("DRUM_SERVER_TYPE") and RuntimeParameters.has("DRUM_GUNICORN_WORKER_CLASS"): - if str(RuntimeParameters.has("DRUM_SERVER_TYPE")).lower() == "gunicorn" and str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower() == "gevent": +if RuntimeParameters.has("DRUM_SERVER_TYPE") and RuntimeParameters.has( + "DRUM_GUNICORN_WORKER_CLASS" +): + if ( + str(RuntimeParameters.has("DRUM_SERVER_TYPE")).lower() == "gunicorn" + and str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower() == "gevent" + ): try: from gevent import monkey + monkey.patch_all() except ImportError: pass diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index 7bf2d09cf..53e5e757b 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -376,7 +376,6 @@ def get_server_type(self): server_type = server_type.lower() return server_type - def _run_flask_app(self, app): host = self._params.get("host", None) port = self._params.get("port", None) @@ -387,7 +386,7 @@ def _run_flask_app(self, app): logger.info("Number of webserver processes: %s", processes) try: if RuntimeParameters.has("USE_NIM_WATCHDOG") and str( - RuntimeParameters.get("USE_NIM_WATCHDOG") + RuntimeParameters.get("USE_NIM_WATCHDOG") ).lower() in ["true", "1", "yes"]: # Start the watchdog thread before running the app self._server_watchdog = Thread( @@ -405,6 +404,7 @@ def _run_flask_app(self, app): except ImportError: BaseApplication = 
None raise DrumCommonException("gunicorn is not installed. Please install gunicorn.") + class GunicornApp(BaseApplication): def __init__(self, app, host, port, params, gunicorn_config): self.application = app @@ -413,37 +413,52 @@ def __init__(self, app, host, port, params, gunicorn_config): self.params = params self.gunicorn_config = gunicorn_config super().__init__() - def load_config(self): + def load_config(self): self.cfg.set("bind", f"{self.host}:{self.port}") - workers = self.params.get("gunicorn_workers") or self.params.get("max_workers") or self.params.get("processes") + workers = ( + self.params.get("gunicorn_workers") + or self.params.get("max_workers") + or self.params.get("processes") + ) self.cfg.set("workers", workers) - self.cfg.set("worker_class", self.gunicorn_config.get("worker_class", "sync")) + self.cfg.set( + "worker_class", self.gunicorn_config.get("worker_class", "sync") + ) self.cfg.set("backlog", self.gunicorn_config.get("backlog", 2048)) self.cfg.set("timeout", self.gunicorn_config.get("timeout", 120)) - self.cfg.set("graceful_timeout", self.gunicorn_config.get("graceful_timeout", 30)) + self.cfg.set( + "graceful_timeout", self.gunicorn_config.get("graceful_timeout", 30) + ) self.cfg.set("keepalive", self.gunicorn_config.get("keepalive", 5)) self.cfg.set("max_requests", self.gunicorn_config.get("max_requests", 2000)) - self.cfg.set("max_requests_jitter", self.gunicorn_config.get("max_requests_jitter", 500)) + self.cfg.set( + "max_requests_jitter", + self.gunicorn_config.get("max_requests_jitter", 500), + ) if self.gunicorn_config.get("worker_connections"): - self.cfg.set("worker_connections", self.gunicorn_config.get("worker_connections")) + self.cfg.set( + "worker_connections", self.gunicorn_config.get("worker_connections") + ) self.cfg.set("loglevel", self.gunicorn_config.get("loglevel", "info")) - - self.cfg.set('accesslog', '-') - self.cfg.set('errorlog', '-') # if you want error logs to stdout - self.cfg.set('access_log_format', '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"') + self.cfg.set("accesslog", "-") + self.cfg.set("errorlog", "-") # if you want error logs to stdout + self.cfg.set( + "access_log_format", + '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"', + ) # Remove unsupported config keys: access_logfile, error_logfile, access_logformat # These must be set via CLI, not config API + def load(self): return self.application gunicorn_config = self.get_gunicorn_config() GunicornApp(app, host, port, self._params, gunicorn_config).run() else: - # Configure the server with timeout settings app.run( host=host, From 2a69a456eeef1f55dece8e3211c5fa4411c7ac03 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 15:44:34 -0400 Subject: [PATCH 08/13] worker class --- custom_model_runner/datarobot_drum/drum/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py index 167ffe867..938d22391 100644 --- a/custom_model_runner/datarobot_drum/drum/main.py +++ b/custom_model_runner/datarobot_drum/drum/main.py @@ -50,7 +50,7 @@ ): if ( str(RuntimeParameters.has("DRUM_SERVER_TYPE")).lower() == "gunicorn" - and str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower() == "gevent" + and str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")).lower() == "gevent" ): try: from gevent import monkey From 5a65d315697c5ab29f51cbcd53ce5f7c4cb035ee Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 16:31:49 
-0400 Subject: [PATCH 09/13] fixed monkey --- .../datarobot_drum/drum/main.py | 32 +++++++++---------- .../drum/root_predictors/prediction_server.py | 1 + 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py index 938d22391..852d80109 100644 --- a/custom_model_runner/datarobot_drum/drum/main.py +++ b/custom_model_runner/datarobot_drum/drum/main.py @@ -4,6 +4,21 @@ This is proprietary source code of DataRobot, Inc. and its affiliates. Released under the terms of DataRobot Tool and Utility Agreement. """ +from ..runtime_parameters import RuntimeParameters + +# Monkey patching for gevent compatibility if running with gunicorn-gevent +if RuntimeParameters.has("DRUM_SERVER_TYPE") and RuntimeParameters.has( + "DRUM_GUNICORN_WORKER_CLASS" +): + if ( + str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower() == "gunicorn" + and str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")).lower() == "gevent" + ): + try: + from gevent import monkey + monkey.patch_all() + except ImportError: + pass from datarobot_drum.drum.lazy_loading.lazy_loading_handler import LazyLoadingHandler @@ -39,25 +54,10 @@ # Run regression user model in fit mode. drum fit --code-dir --input --output --target-type regression --target --verbose """ + import os import signal import sys -from datarobot_drum import RuntimeParameters - -# Monkey patching for gevent compatibility if running with gunicorn-gevent -if RuntimeParameters.has("DRUM_SERVER_TYPE") and RuntimeParameters.has( - "DRUM_GUNICORN_WORKER_CLASS" -): - if ( - str(RuntimeParameters.has("DRUM_SERVER_TYPE")).lower() == "gunicorn" - and str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")).lower() == "gevent" - ): - try: - from gevent import monkey - - monkey.patch_all() - except ImportError: - pass from datarobot_drum.drum.common import config_logging, setup_otel from datarobot_drum.drum.utils.setup import setup_options diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index 53e5e757b..29c4bfe01 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -415,6 +415,7 @@ def __init__(self, app, host, port, params, gunicorn_config): super().__init__() def load_config(self): + self.cfg.set("bind", f"{self.host}:{self.port}") workers = ( self.params.get("gunicorn_workers") From cf4aad47c2e62f8ea9afaa2fd09fcb19563130f4 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 16:32:04 -0400 Subject: [PATCH 10/13] fixed monkey --- custom_model_runner/datarobot_drum/drum/main.py | 1 + .../datarobot_drum/drum/root_predictors/prediction_server.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py index 852d80109..349a3a8b8 100644 --- a/custom_model_runner/datarobot_drum/drum/main.py +++ b/custom_model_runner/datarobot_drum/drum/main.py @@ -16,6 +16,7 @@ ): try: from gevent import monkey + monkey.patch_all() except ImportError: pass diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index 29c4bfe01..53e5e757b 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py 
+++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -415,7 +415,6 @@ def __init__(self, app, host, port, params, gunicorn_config): super().__init__() def load_config(self): - self.cfg.set("bind", f"{self.host}:{self.port}") workers = ( self.params.get("gunicorn_workers") From 87e98408666cb350258a6c6dc81a1d0f79bba6e8 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 22:38:20 -0400 Subject: [PATCH 11/13] corrected --- .../drum/root_predictors/prediction_server.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index 53e5e757b..ee521efef 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -366,6 +366,11 @@ def get_gunicorn_config(self): if loglevel.lower() in {"debug", "info", "warning", "error", "critical"}: config["loglevel"] = loglevel + if RuntimeParameters.has("DRUM_GUNICORN_WORKERS"): + workers = int(RuntimeParameters.get("DRUM_GUNICORN_WORKERS")) + if 0 < workers < 200: + config["workers"] = workers + return config def get_server_type(self): @@ -417,10 +422,11 @@ def __init__(self, app, host, port, params, gunicorn_config): def load_config(self): self.cfg.set("bind", f"{self.host}:{self.port}") workers = ( - self.params.get("gunicorn_workers") - or self.params.get("max_workers") + self.params.get("max_workers") or self.params.get("processes") ) + if self.gunicorn_config.get("workers"): + workers = self.gunicorn_config.get("workers") self.cfg.set("workers", workers) self.cfg.set( @@ -429,7 +435,7 @@ def load_config(self): self.cfg.set("backlog", self.gunicorn_config.get("backlog", 2048)) self.cfg.set("timeout", self.gunicorn_config.get("timeout", 120)) self.cfg.set( - "graceful_timeout", self.gunicorn_config.get("graceful_timeout", 30) + "graceful_timeout", self.gunicorn_config.get("graceful_timeout", 60) ) self.cfg.set("keepalive", self.gunicorn_config.get("keepalive", 5)) self.cfg.set("max_requests", self.gunicorn_config.get("max_requests", 2000)) @@ -444,12 +450,12 @@ def load_config(self): ) self.cfg.set("loglevel", self.gunicorn_config.get("loglevel", "info")) - self.cfg.set("accesslog", "-") + '''self.cfg.set("accesslog", "-") self.cfg.set("errorlog", "-") # if you want error logs to stdout self.cfg.set( "access_log_format", '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"', - ) + )''' # Remove unsupported config keys: access_logfile, error_logfile, access_logformat # These must be set via CLI, not config API From f5431230150d47a7b2b3495a0ead072b72ad6ef9 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Fri, 29 Aug 2025 22:38:37 -0400 Subject: [PATCH 12/13] add file --- .../root_predictors/gevent_stdout_flusher.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py b/custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py new file mode 100644 index 000000000..b0df1ba0d --- /dev/null +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py @@ -0,0 +1,93 @@ +import sys +import time +try: + import gevent + from gevent import Greenlet + HAS_GEVENT = True 
+except ImportError: + HAS_GEVENT = False + import threading + +HAS_GEVENT = True + + +class GeventCompatibleStdoutFlusher: + """An implementation to flush the stdout after a certain time of no activity. + Compatible with both gevent and threading environments.""" + + def __init__(self, max_time_until_flushing=1.0): + self._max_time_until_flushing = max_time_until_flushing + self._last_predict_time = None + self._flusher_greenlet = None + self._flusher_thread = None + self._stop_event = None + self._running = False + + def start(self): + """Start the stdout flusher.""" + if self._running: + return + + self._running = True + + if HAS_GEVENT: + self._flusher_greenlet = gevent.spawn(self._flush_greenlet_method) + else: + self._stop_event = threading.Event() + self._flusher_thread = threading.Thread(target=self._flush_thread_method) + """Check if the stdout flusher is alive.""" + if HAS_GEVENT and self._flusher_greenlet: + return not self._flusher_greenlet.dead + elif self._flusher_thread: + return self._flusher_thread.is_alive() + return False + + def stop(self): + """Stop the flusher in a synchronous fashion.""" + if not self._running: + return + + self._running = False + + if HAS_GEVENT and self._flusher_greenlet: + self._flusher_greenlet.kill() + self._flusher_greenlet = None + elif self._flusher_thread and self._stop_event: + self._stop_event.set() + self._flusher_thread.join(timeout=2.0) # Timeout to prevent hanging + self._flusher_thread = None + self._stop_event = None + + def set_last_activity_time(self): + """Set the last activity time that will be used as the reference for time comparison.""" + self._last_predict_time = self._current_time() + + @staticmethod + def _current_time(): + return time.time() + + def _flush_greenlet_method(self): + """Gevent greenlet method for stdout flushing""" + try: + while self._running: + self._process_stdout_flushing() + #gevent.sleep(self._max_time_until_flushing) + gevent.sleep(0) + except gevent.GreenletExit: + pass # Normal termination + + def _flush_thread_method(self): + """Threading method for stdout flushing""" + while self._running and not self._stop_event.wait(self._max_time_until_flushing): + self._process_stdout_flushing() + + def _process_stdout_flushing(self): + if self._is_predict_time_set_and_max_waiting_time_expired(): + sys.stdout.flush() + sys.stderr.flush() + + def _is_predict_time_set_and_max_waiting_time_expired(self): + if self._last_predict_time is None: + return False + + return (self._current_time() - self._last_predict_time) >= self._max_time_until_flushing From 799bbc466af1970a8e9500749c1dfaeef009da02 Mon Sep 17 00:00:00 2001 From: Sergey Gavrenkov Date: Sun, 31 Aug 2025 00:16:35 -0400 Subject: [PATCH 13/13] add self._terminate --- .../root_predictors/gevent_stdout_flusher.py | 7 +- .../drum/root_predictors/prediction_server.py | 73 +++++++++++++++++-- custom_model_runner/requirements.txt | 4 +- 3 files changed, 75 insertions(+), 9 deletions(-) diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py b/custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py index b0df1ba0d..70aebfdaa 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/gevent_stdout_flusher.py @@ -1,3 +1,4 @@ +import logging import sys import time try: @@ -10,6 +11,7 @@ HAS_GEVENT = True +logger = logging.getLogger(__name__) class GeventCompatibleStdoutFlusher: """An implementation to flush the 
stdout after a certain time of no activity. @@ -45,18 +47,21 @@ def start(self): def stop(self): """Stop the flusher in a synchronous fashion.""" if not self._running: + logging.error("Flusher thread stopped 1.") return self._running = False - + logging.error("Flusher thread stopped 2.") if HAS_GEVENT and self._flusher_greenlet: self._flusher_greenlet.kill() self._flusher_greenlet = None elif self._flusher_thread and self._stop_event: + logging.error("Flusher thread stopped 3.") self._stop_event.set() self._flusher_thread.join(timeout=2.0) # Timeout to prevent hanging self._flusher_thread = None self._stop_event = None + logging.error("Flusher thread stopped 4.") def set_last_activity_time(self): """Set the last activity time that will be used as the reference for time comparison.""" diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index ee521efef..467d9a863 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -7,12 +7,14 @@ import logging import os import sys +import threading import time from pathlib import Path from threading import Thread import subprocess import signal +import psutil import requests from flask import Response, jsonify, request from werkzeug.exceptions import HTTPException @@ -312,7 +314,7 @@ def handle_exception(e): app = get_flask_app(model_api) self.load_flask_extensions(app) - self._run_flask_app(app) + self._run_flask_app(app, self._terminate) if self._stats_collector: self._stats_collector.print_reports() @@ -353,7 +355,7 @@ def get_gunicorn_config(self): if RuntimeParameters.has("DRUM_GUNICORN_MAX_REQUESTS"): max_requests = int(RuntimeParameters.get("DRUM_GUNICORN_MAX_REQUESTS")) - if 100 <= max_requests <= 10000: + if 1 <= max_requests <= 10000: config["max_requests"] = max_requests if RuntimeParameters.has("DRUM_GUNICORN_MAX_REQUESTS_JITTER"): @@ -381,7 +383,7 @@ def get_server_type(self): server_type = server_type.lower() return server_type - def _run_flask_app(self, app): + def _run_flask_app(self, app, termination_hook): host = self._params.get("host", None) port = self._params.get("port", None) server_type = self.get_server_type() @@ -411,12 +413,13 @@ def _run_flask_app(self, app): raise DrumCommonException("gunicorn is not installed. 
Please install gunicorn.") class GunicornApp(BaseApplication): - def __init__(self, app, host, port, params, gunicorn_config): + def __init__(self, app, host, port, params, gunicorn_config, termination_hook): self.application = app self.host = host self.port = port self.params = params self.gunicorn_config = gunicorn_config + self.termination_hook = termination_hook super().__init__() def load_config(self): @@ -428,6 +431,8 @@ def load_config(self): if self.gunicorn_config.get("workers"): workers = self.gunicorn_config.get("workers") self.cfg.set("workers", workers) + self.cfg.set("reuse_port", True) + self.cfg.set("preload_app", True) self.cfg.set( "worker_class", self.gunicorn_config.get("worker_class", "sync") @@ -438,7 +443,7 @@ def load_config(self): "graceful_timeout", self.gunicorn_config.get("graceful_timeout", 60) ) self.cfg.set("keepalive", self.gunicorn_config.get("keepalive", 5)) - self.cfg.set("max_requests", self.gunicorn_config.get("max_requests", 2000)) + self.cfg.set("max_requests", self.gunicorn_config.get("max_requests", 1000)) self.cfg.set( "max_requests_jitter", self.gunicorn_config.get("max_requests_jitter", 500), @@ -450,6 +455,10 @@ def load_config(self): ) self.cfg.set("loglevel", self.gunicorn_config.get("loglevel", "info")) + # Properly assign the worker_exit hook + if self.termination_hook: + self.cfg.set("worker_exit", self._worker_exit_hook) + '''self.cfg.set("accesslog", "-") self.cfg.set("errorlog", "-") # if you want error logs to stdout self.cfg.set( @@ -462,8 +471,60 @@ def load_config(self): def load(self): return self.application + def _worker_exit_hook(self, server, worker): + pid = worker.pid + server.log.info(f"[HOOK] Worker PID {pid} exiting — running termination hook.") + + def run_hook(): + try: + self.termination_hook() + server.log.info(f"[HOOK] Worker PID {pid} termination logic completed.") + except Exception as e: + server.log.error(f"[HOOK ERROR] Worker PID {pid}: {e}") + + try: + # 🔍 Find the process that is listening on the same port + port = self.port + occupying_proc = None + + for proc in psutil.process_iter(['pid', 'name']): + try: + for conn in proc.connections(kind='inet'): + if conn.status == psutil.CONN_LISTEN and conn.laddr.port == port: + occupying_proc = proc + break + if occupying_proc: + break + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + if occupying_proc: + server.log.info( + f"[PORT OCCUPIED] Port {port} is occupied by process: PID={occupying_proc.pid}, name={occupying_proc.name()}" + ) + else: + server.log.info(f"[PORT FREE] Port {port} is free.") + + server.log.info(f"[HOOK] Worker PID {pid} termination logic completed.") + except Exception as e: + server.log.error(f"[HOOK ERROR] Worker PID {pid}: {e}") + + + + thread = threading.Thread(target=run_hook) + thread.start() + + + '''for thread in threads: + try: + server.log.info(f"Name: {thread.name}, ID: {thread.ident}, Daemon: {thread.daemon}") + ''' + #server.log.info(f"Active thread count:", threading.active_count()) + thread.join(timeout=20) + server.log.info(f"[HOOK] Worker PID {pid} cleanup done or timed out.") + gunicorn_config = self.get_gunicorn_config() - GunicornApp(app, host, port, self._params, gunicorn_config).run() + GunicornApp(app, host, port, self._params, gunicorn_config, termination_hook).run() else: # Configure the server with timeout settings app.run( diff --git a/custom_model_runner/requirements.txt b/custom_model_runner/requirements.txt index d73b3b8b4..d9b6a68fd 100644 --- a/custom_model_runner/requirements.txt +++
b/custom_model_runner/requirements.txt @@ -3,8 +3,8 @@ argcomplete trafaret>=2.0.0 docker>=4.2.2 flask -gevent>=22.10.2 -gunicorn>=20.1.0 +gevent +gunicorn jinja2>=3.0.0 memory_profiler<1.0.0 numpy
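
Taken together, these patches let the DRUM prediction server run its Flask app under gunicorn by embedding gunicorn in-process with the custom-application pattern (subclass gunicorn's BaseApplication, implement load_config() and load()), with worker count, timeouts, request recycling and log level taken from DRUM_GUNICORN_* runtime parameters and range-checked before use. The sketch below shows that pattern in isolation, assuming a toy Flask app: create_app, bounded_int, the plain GUNICORN_* environment variables and the default/bound values are illustrative stand-ins, not DRUM's actual names or limits.

import os

from flask import Flask, jsonify
from gunicorn.app.base import BaseApplication


def create_app():
    # Hypothetical stand-in for the DRUM Flask prediction app.
    app = Flask(__name__)

    @app.route("/ping")
    def ping():
        return jsonify({"message": "pong"})

    return app


def bounded_int(name, default, low, high):
    # Read an integer from the environment; fall back to the default when it is
    # missing, malformed, or outside [low, high], the same style of range check
    # the patches apply to the DRUM_GUNICORN_* runtime parameters.
    try:
        value = int(os.environ.get(name, default))
    except (TypeError, ValueError):
        return default
    return value if low <= value <= high else default


class EmbeddedGunicornApp(BaseApplication):
    """Run gunicorn in-process: set config in load_config(), return the WSGI app from load()."""

    def __init__(self, app, host, port):
        self.application = app
        self.host = host
        self.port = port
        super().__init__()

    def load_config(self):
        self.cfg.set("bind", f"{self.host}:{self.port}")
        self.cfg.set("workers", bounded_int("GUNICORN_WORKERS", 2, 1, 200))
        self.cfg.set("worker_class", os.environ.get("GUNICORN_WORKER_CLASS", "sync"))
        self.cfg.set("backlog", bounded_int("GUNICORN_BACKLOG", 2048, 1, 2048))
        self.cfg.set("timeout", bounded_int("GUNICORN_TIMEOUT", 120, 1, 3600))
        self.cfg.set("graceful_timeout", bounded_int("GUNICORN_GRACEFUL_TIMEOUT", 60, 1, 3600))
        self.cfg.set("keepalive", bounded_int("GUNICORN_KEEP_ALIVE", 5, 1, 3600))
        self.cfg.set("max_requests", bounded_int("GUNICORN_MAX_REQUESTS", 1000, 1, 10000))
        self.cfg.set("max_requests_jitter", bounded_int("GUNICORN_MAX_REQUESTS_JITTER", 500, 1, 10000))
        self.cfg.set("loglevel", os.environ.get("GUNICORN_LOG_LEVEL", "info"))

    def load(self):
        # Called by gunicorn to obtain the WSGI application.
        return self.application


if __name__ == "__main__":
    EmbeddedGunicornApp(create_app(), "0.0.0.0", 8080).run()

Embedding gunicorn this way keeps the server inside the existing drum process, so predictor setup, stats collection and the termination hook stay in one interpreter instead of being handed off to the gunicorn CLI. When a gevent worker class is selected, the patches also run gevent.monkey.patch_all() at the very top of main.py, since monkey patching must happen before the rest of the runtime is imported.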