diff --git a/custom_model_runner/datarobot_drum/drum/args_parser.py b/custom_model_runner/datarobot_drum/drum/args_parser.py
index 8860e4218..864869feb 100644
--- a/custom_model_runner/datarobot_drum/drum/args_parser.py
+++ b/custom_model_runner/datarobot_drum/drum/args_parser.py
@@ -122,6 +122,16 @@ def _reg_args_lazy_loading_file(*parsers):
                 help="Path to a lazy loading values file. (env: LAZY_LOADING_FILE)",
             )
 
+    @staticmethod
+    def _reg_arg_server_type(*parsers):
+        for parser in parsers:
+            parser.add_argument(
+                '--server-type',
+                type=str,
+                default=None,
+                help='Server type to run the prediction server with: "flask" (default) or "gunicorn"',
+            )
+
     @staticmethod
     def _reg_arg_output(*parsers):
         for parser in parsers:
@@ -724,6 +734,27 @@ def _reg_arg_triton_server_access(*parsers):
                 help="NVIDIA Triton Inference Server GRPC port",
             )
 
+    @staticmethod
+    def _reg_arg_gunicorn_options(*parsers):
+        for parser in parsers:
+            parser.add_argument('--gunicorn-backlog', type=int, default=2024, help='Gunicorn backlog size')
+            parser.add_argument('--gunicorn-timeout', type=int, default=120, help='Gunicorn worker timeout (seconds)')
+            parser.add_argument('--gunicorn-graceful-timeout', type=int, default=30, help='Gunicorn graceful timeout (seconds)')
+            parser.add_argument('--gunicorn-keep-alive', type=int, default=5, help='Gunicorn keep-alive (seconds)')
+            parser.add_argument('--gunicorn-max-requests', type=int, default=2000, help='Gunicorn max requests per worker')
+            parser.add_argument('--gunicorn-max-requests-jitter', type=int, default=500,
+                                help='Gunicorn max requests jitter')
+            parser.add_argument('--gunicorn-log-level', type=str, default='info', help='Gunicorn log level')
+            parser.add_argument('--gunicorn-access-logfile', type=str, default='-', help='Gunicorn access logfile')
+            parser.add_argument('--gunicorn-error-logfile', type=str, default='-', help='Gunicorn error logfile')
+            parser.add_argument('--gunicorn-access-logformat', type=str,
+                                default='%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"',
+                                help='Gunicorn access log format')
+            parser.add_argument('--gunicorn-workers', type=int, default=None,
+                                help='Gunicorn number of workers (overrides --max-workers)')
+            parser.add_argument('--gunicorn-worker-class', type=str, default=None,
+                                help='Gunicorn worker class (e.g. sync, gevent)')
+
     @staticmethod
     def _register_subcommand_perf_test(subparsers):
         desc = """
@@ -1026,6 +1057,9 @@ def get_arg_parser():
             score_parser, server_parser, validation_parser
         )
 
+        CMRunnerArgsRegistry._reg_arg_server_type(server_parser)
+        CMRunnerArgsRegistry._reg_arg_gunicorn_options(server_parser)
+
         return parser
 
     @staticmethod
diff --git a/custom_model_runner/datarobot_drum/drum/description.py b/custom_model_runner/datarobot_drum/drum/description.py
index 83a754e6c..8a88317f6 100644
--- a/custom_model_runner/datarobot_drum/drum/description.py
+++ b/custom_model_runner/datarobot_drum/drum/description.py
@@ -4,6 +4,6 @@
 This is proprietary source code of DataRobot, Inc. and its affiliates.
 Released under the terms of DataRobot Tool and Utility Agreement.
""" -version = "1.16.24" +version = "1.16.25" __version__ = version project_name = "datarobot-drum" diff --git a/custom_model_runner/datarobot_drum/drum/drum.py b/custom_model_runner/datarobot_drum/drum/drum.py index 37df6da5e..efb3a6887 100644 --- a/custom_model_runner/datarobot_drum/drum/drum.py +++ b/custom_model_runner/datarobot_drum/drum/drum.py @@ -771,6 +771,20 @@ def _prepare_prediction_server_or_batch_pipeline(self, run_language): "target_type": self.target_type.value, "user_secrets_mount_path": getattr(options, "user_secrets_mount_path", None), "user_secrets_prefix": getattr(options, "user_secrets_prefix", None), + "server_type": getattr(options, "server_type", None), + # Gunicorn options + "gunicorn_backlog": getattr(options, "gunicorn_backlog", None), + "gunicorn_timeout": getattr(options, "gunicorn_timeout", None), + "gunicorn_graceful_timeout": getattr(options, "gunicorn_graceful_timeout", None), + "gunicorn_keep_alive": getattr(options, "gunicorn_keep_alive", None), + "gunicorn_max_requests": getattr(options, "gunicorn_max_requests", None), + "gunicorn_max_requests_jitter": getattr(options, "gunicorn_max_requests_jitter", None), + "gunicorn_log_level": getattr(options, "gunicorn_log_level", None), + "gunicorn_access_logfile": getattr(options, "gunicorn_access_logfile", None), + "gunicorn_error_logfile": getattr(options, "gunicorn_error_logfile", None), + "gunicorn_access_logformat": getattr(options, "gunicorn_access_logformat", None), + "gunicorn_workers": getattr(options, "gunicorn_workers", None), + "gunicorn_worker_class": getattr(options, "gunicorn_worker_class", None), } if self.run_mode == RunMode.SCORE: diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py index 94aa0e5b5..b5775818a 100644 --- a/custom_model_runner/datarobot_drum/drum/main.py +++ b/custom_model_runner/datarobot_drum/drum/main.py @@ -43,6 +43,17 @@ import signal import sys +# Monkey patching for gevent compatibility if running with gunicorn-gevent +if ( + "gunicorn-gevent" in sys.argv or os.environ.get("SERVER_TYPE") == "gunicorn-gevent" +): + try: + from gevent import monkey + + monkey.patch_all() + except ImportError: + pass + from datarobot_drum.drum.common import config_logging, setup_otel from datarobot_drum.drum.utils.setup import setup_options from datarobot_drum.drum.enum import RunMode diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py index d2a6b3d5e..6ba0514e0 100644 --- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py +++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py @@ -13,6 +13,8 @@ from werkzeug.exceptions import HTTPException from opentelemetry import trace + +from datarobot_drum import RuntimeParameters from datarobot_drum.drum.description import version as drum_version from datarobot_drum.drum.enum import ( FLASK_EXT_FILE_NAME, @@ -292,18 +294,122 @@ def handle_exception(e): return [] + def get_gunicorn_config(self): + config = {} + if RuntimeParameters.has("DRUM_GUNICORN_WORKER_CLASS"): + worker_class = str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")) + if worker_class.lower() in {"sync", "gevent"}: + config["worker_class"] = worker_class + + if RuntimeParameters.has("DRUM_GUNICORN_WORKER_CONNECTIONS"): + worker_connections = int(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CONNECTIONS")) + if 1 <= worker_connections <= 10000: + 
config["worker_connections"] = worker_connections + + if RuntimeParameters.has("DRUM_GUNICORN_BACKLOG"): + backlog = int(RuntimeParameters.get("DRUM_GUNICORN_BACKLOG")) + if 1 <= backlog <= 2048: + config["backlog"] = backlog + + if RuntimeParameters.has("DRUM_GUNICORN_TIMEOUT"): + timeout = int(RuntimeParameters.get("DRUM_GUNICORN_TIMEOUT")) + if 1 <= timeout <= 3600: + config["timeout"] = timeout + + if RuntimeParameters.has("DRUM_GUNICORN_GRACEFUL_TIMEOUT"): + graceful_timeout = int(RuntimeParameters.get("DRUM_GUNICORN_GRACEFUL_TIMEOUT")) + if 1 <= graceful_timeout <= 3600: + config["graceful_timeout"] = graceful_timeout + + if RuntimeParameters.has("DRUM_GUNICORN_KEEP_ALIVE"): + keepalive = int(RuntimeParameters.get("DRUM_GUNICORN_KEEP_ALIVE")) + if 1 <= keepalive <= 3600: + config["keepalive"] = keepalive + + if RuntimeParameters.has("DRUM_GUNICORN_MAX_REQUESTS"): + max_requests = int(RuntimeParameters.get("DRUM_GUNICORN_MAX_REQUESTS")) + if 100 <= max_requests <= 10000: + config["max_requests"] = max_requests + + if RuntimeParameters.has("DRUM_GUNICORN_MAX_REQUESTS_JITTER"): + max_requests_jitter = int(RuntimeParameters.get("DRUM_GUNICORN_MAX_REQUESTS_JITTER")) + if 1 <= max_requests_jitter <= 10000: + config["max_requests_jitter"] = max_requests_jitter + + if RuntimeParameters.has("DRUM_GUNICORN_LOG_LEVEL"): + loglevel = str(RuntimeParameters.get("DRUM_GUNICORN_LOG_LEVEL")) + if loglevel.lower() in {"debug", "info", "warning", "error", "critical"}: + config["loglevel"] = loglevel + + return config + + def get_server_type(self): + server_type = "flask" + if RuntimeParameters.has("DRUM_SERVER_TYPE"): + server_type = str(RuntimeParameters.get("DRUM_SERVER_TYPE")) + if server_type.lower() in {"flask", "gunicorn"}: + server_type = server_type.lower() + return server_type + + def _run_flask_app(self, app): host = self._params.get("host", None) port = self._params.get("port", None) - + server_type = self.get_server_type() processes = 1 if self._params.get("processes"): processes = self._params.get("processes") logger.info("Number of webserver processes: %s", processes) try: - app.run(host, port, threaded=False, processes=processes) + logger.info(self._params) + if server_type == "gunicorn": + logger.info("Starting gunicorn server") + try: + from gunicorn.app.base import BaseApplication + except ImportError: + BaseApplication = None + raise DrumCommonException("gunicorn is not installed. 
+
+                class GunicornApp(BaseApplication):
+                    def __init__(self, app, host, port, params, gunicorn_config):
+                        self.application = app
+                        self.host = host
+                        self.port = port
+                        self.params = params
+                        self.gunicorn_config = gunicorn_config
+                        super().__init__()
+
+                    def load_config(self):
+                        self.cfg.set("bind", f"{self.host}:{self.port}")
+                        workers = (
+                            self.params.get("gunicorn_workers")
+                            or self.params.get("max_workers")
+                            or self.params.get("processes")
+                            or 1
+                        )
+                        self.cfg.set("workers", workers)
+
+                        self.cfg.set("worker_class", self.gunicorn_config.get("worker_class", "sync"))
+                        self.cfg.set("backlog", self.gunicorn_config.get("backlog", 2048))
+                        self.cfg.set("timeout", self.gunicorn_config.get("timeout", 120))
+                        self.cfg.set("graceful_timeout", self.gunicorn_config.get("graceful_timeout", 30))
+                        self.cfg.set("keepalive", self.gunicorn_config.get("keepalive", 5))
+                        self.cfg.set("max_requests", self.gunicorn_config.get("max_requests", 2000))
+                        self.cfg.set("max_requests_jitter", self.gunicorn_config.get("max_requests_jitter", 500))
+                        if self.gunicorn_config.get("worker_connections"):
+                            self.cfg.set("worker_connections", self.gunicorn_config.get("worker_connections"))
+                        self.cfg.set("loglevel", self.gunicorn_config.get("loglevel", "info"))
+
+                        # gunicorn's config API uses the "accesslog"/"errorlog"/"access_log_format"
+                        # setting names; log access and error output to stdout/stderr by default
+                        self.cfg.set("accesslog", "-")
+                        self.cfg.set("errorlog", "-")
+                        self.cfg.set(
+                            "access_log_format",
+                            '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"',
+                        )
+
+                    def load(self):
+                        return self.application
+
+                gunicorn_config = self.get_gunicorn_config()
+                GunicornApp(app, host, port, self._params, gunicorn_config).run()
+            else:
+                app.run(host, port, threaded=False, processes=processes)
         except OSError as e:
-            raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port))
+            raise DrumCommonException(f"{e}: host: {host}; port: {port}")
 
     def terminate(self):
         terminate_op = getattr(self._predictor, "terminate", None)
diff --git a/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2 b/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2
index fbd753997..6eee2ca09 100644
--- a/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2
+++ b/custom_model_runner/datarobot_drum/resource/pipelines/prediction_server_pipeline.json.j2
@@ -20,6 +20,7 @@
                 "processes": {{ processes }},
                 "monitor": "{{ monitor }}",
                 "monitor_embedded": "{{ monitor_embedded }}",
+                "server_type": "{{ server_type }}",
                 "model_id": "{{ model_id }}",
                 "deployment_id": {{ deployment_id | jsonify }},
                 "monitor_settings": {{ monitor_settings | jsonify }},
@@ -33,7 +34,16 @@
                 "sidecar": {{ sidecar | jsonify }},
                 "triton_host": {{ triton_host | jsonify }},
                 "triton_http_port": {{ triton_http_port | jsonify }},
-                "triton_grpc_port": {{ triton_grpc_port | jsonify }}
+                "triton_grpc_port": {{ triton_grpc_port | jsonify }},
+                "gunicorn_backlog": {{ gunicorn_backlog }},
+                "gunicorn_timeout": {{ gunicorn_timeout }},
+                "gunicorn_graceful_timeout": {{ gunicorn_graceful_timeout }},
+                "gunicorn_keep_alive": {{ gunicorn_keep_alive }},
+                "gunicorn_max_requests": {{ gunicorn_max_requests }},
+                "gunicorn_max_requests_jitter": {{ gunicorn_max_requests_jitter }},
+                "gunicorn_log_level": "{{ gunicorn_log_level }}",
+                "gunicorn_workers": {{ gunicorn_workers }},
+                "gunicorn_worker_class": "{{ gunicorn_worker_class }}"
             }
         }
     ]
diff --git a/custom_model_runner/requirements.txt b/custom_model_runner/requirements.txt
index db2466b50..d73b3b8b4 100644
--- a/custom_model_runner/requirements.txt
+++ b/custom_model_runner/requirements.txt
@@ -3,6 +3,8 @@ argcomplete
 trafaret>=2.0.0
 docker>=4.2.2
 flask
+gevent>=22.10.2
+gunicorn>=20.1.0
 jinja2>=3.0.0
 memory_profiler<1.0.0
 numpy
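
Note on the embedding pattern: `GunicornApp` in `prediction_server.py` follows gunicorn's documented "custom application" approach: subclass `gunicorn.app.base.BaseApplication`, apply settings in `load_config()`, and let gunicorn obtain the WSGI callable from `load()`. A minimal, self-contained sketch of that pattern is below; the names `demo_app` and `DemoGunicornApp` are illustrative only and are not part of this change.

```python
# Minimal sketch of gunicorn's custom-application pattern, mirroring GunicornApp above.
from gunicorn.app.base import BaseApplication


def demo_app(environ, start_response):
    # Trivial WSGI callable standing in for the DRUM Flask app.
    body = b"ok"
    start_response("200 OK", [("Content-Type", "text/plain"), ("Content-Length", str(len(body)))])
    return [body]


class DemoGunicornApp(BaseApplication):
    def __init__(self, application, options=None):
        self.application = application
        self.options = options or {}
        super().__init__()

    def load_config(self):
        # Called once at startup; only known gunicorn settings may be set.
        for key, value in self.options.items():
            if key in self.cfg.settings and value is not None:
                self.cfg.set(key, value)

    def load(self):
        # Called in each worker to obtain the WSGI application.
        return self.application


if __name__ == "__main__":
    DemoGunicornApp(demo_app, {"bind": "127.0.0.1:8080", "workers": 2}).run()
```

Because the server is embedded this way, the new `--gunicorn-*` flags are translated into `cfg.set(...)` calls instead of being passed to the `gunicorn` CLI. For local testing the options would be exercised with something like `drum server --code-dir <model_dir> --address localhost:6789 --server-type gunicorn --gunicorn-workers 4` (exact invocation depends on the existing `server` subcommand arguments).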
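
The `monkey.patch_all()` guard added near the top of `main.py` sits before the `datarobot_drum` imports because, per gevent's guidance, patching should happen before other modules are imported so they do not capture references to the unpatched blocking primitives. A small illustrative sketch of that ordering constraint; the use of `urllib.request` here is only an example, not something DRUM does at this point:

```python
# Sketch: gevent monkey-patching must precede imports of modules that bind to
# the stdlib socket/ssl/threading primitives.
from gevent import monkey

monkey.patch_all()  # patch the stdlib first

import urllib.request  # noqa: E402  (imported after patching, so it uses cooperative sockets)


def fetch_status(url):
    # Under a gevent worker this call yields to other greenlets while waiting on I/O.
    with urllib.request.urlopen(url) as resp:
        return resp.status
```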