Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions custom_model_runner/datarobot_drum/drum/args_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ def _reg_args_lazy_loading_file(*parsers):
help="Path to a lazy loading values file. (env: LAZY_LOADING_FILE)",
)

@staticmethod
def _reg_arg_server_type(*parsers):
for parser in parsers:
parser.add_argument(
'--server-type',
type=str,
default=None,
help='Type of server to run (optional, string)'
)

@staticmethod
def _reg_arg_output(*parsers):
for parser in parsers:
Expand Down Expand Up @@ -724,6 +734,27 @@ def _reg_arg_triton_server_access(*parsers):
help="NVIDIA Triton Inference Server GRPC port",
)

@staticmethod
def _reg_arg_gunicorn_options(*parsers):
for parser in parsers:
parser.add_argument('--gunicorn-backlog', type=int, default=2024, help='Gunicorn backlog')
parser.add_argument('--gunicorn-timeout', type=int, default=120, help='Gunicorn timeout')
parser.add_argument('--gunicorn-graceful-timeout', type=int, default=30, help='Gunicorn graceful timeout')
parser.add_argument('--gunicorn-keep-alive', type=int, default=5, help='Gunicorn keep alive')
parser.add_argument('--gunicorn-max-requests', type=int, default=2000, help='Gunicorn max requests')
parser.add_argument('--gunicorn-max-requests-jitter', type=int, default=500,
help='Gunicorn max requests jitter')
parser.add_argument('--gunicorn-log-level', type=str, default='info', help='Gunicorn log level')
parser.add_argument('--gunicorn-access-logfile', type=str, default='-', help='Gunicorn access logfile')
parser.add_argument('--gunicorn-error-logfile', type=str, default='-', help='Gunicorn error logfile')
parser.add_argument('--gunicorn-access-logformat', type=str,
default='%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"',
help='Gunicorn access log format')
parser.add_argument('--gunicorn-workers', type=int, default=None,
help='Gunicorn number of workers (overrides --max-workers)')
parser.add_argument('--gunicorn-worker-class', type=str, default=None,
help='Gunicorn worker class (e.g. sync, gevent, eventlet, etc.)')

@staticmethod
def _register_subcommand_perf_test(subparsers):
desc = """
Expand Down Expand Up @@ -1026,6 +1057,9 @@ def get_arg_parser():
score_parser, server_parser, validation_parser
)

CMRunnerArgsRegistry._reg_arg_server_type(server_parser)
CMRunnerArgsRegistry._reg_arg_gunicorn_options(server_parser)

return parser

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion custom_model_runner/datarobot_drum/drum/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
This is proprietary source code of DataRobot, Inc. and its affiliates.
Released under the terms of DataRobot Tool and Utility Agreement.
"""
# Package version; ``__version__`` mirrors it so both spellings stay in sync.
version = "1.16.25"
__version__ = version
project_name = "datarobot-drum"
14 changes: 14 additions & 0 deletions custom_model_runner/datarobot_drum/drum/drum.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,20 @@ def _prepare_prediction_server_or_batch_pipeline(self, run_language):
"target_type": self.target_type.value,
"user_secrets_mount_path": getattr(options, "user_secrets_mount_path", None),
"user_secrets_prefix": getattr(options, "user_secrets_prefix", None),
"server_type": getattr(options, "server_type", None),
# Gunicorn options
"gunicorn_backlog": getattr(options, "gunicorn_backlog", None),
"gunicorn_timeout": getattr(options, "gunicorn_timeout", None),
"gunicorn_graceful_timeout": getattr(options, "gunicorn_graceful_timeout", None),
"gunicorn_keep_alive": getattr(options, "gunicorn_keep_alive", None),
"gunicorn_max_requests": getattr(options, "gunicorn_max_requests", None),
"gunicorn_max_requests_jitter": getattr(options, "gunicorn_max_requests_jitter", None),
"gunicorn_log_level": getattr(options, "gunicorn_log_level", None),
"gunicorn_access_logfile": getattr(options, "gunicorn_access_logfile", None),
"gunicorn_error_logfile": getattr(options, "gunicorn_error_logfile", None),
"gunicorn_access_logformat": getattr(options, "gunicorn_access_logformat", None),
"gunicorn_workers": getattr(options, "gunicorn_workers", None),
"gunicorn_worker_class": getattr(options, "gunicorn_worker_class", None),
}

if self.run_mode == RunMode.SCORE:
Expand Down
11 changes: 11 additions & 0 deletions custom_model_runner/datarobot_drum/drum/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@
import signal
import sys

# Monkey patching for gevent compatibility if running with gunicorn-gevent.
# This runs ahead of the datarobot_drum imports below so that modules loaded
# afterwards see the patched standard library.
if "gunicorn-gevent" in sys.argv or os.environ.get("SERVER_TYPE") == "gunicorn-gevent":
    try:
        from gevent import monkey

        monkey.patch_all()
    except ImportError as exc:
        # Still best-effort (startup continues), but say so: the gevent server
        # type was explicitly requested and will not behave as expected
        # without patching. Logging is not configured yet, hence stderr.
        print(
            f"WARNING: gunicorn-gevent requested but gevent could not be imported: {exc}",
            file=sys.stderr,
        )

from datarobot_drum.drum.common import config_logging, setup_otel
from datarobot_drum.drum.utils.setup import setup_options
from datarobot_drum.drum.enum import RunMode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from werkzeug.exceptions import HTTPException

from opentelemetry import trace

from datarobot_drum import RuntimeParameters
from datarobot_drum.drum.description import version as drum_version
from datarobot_drum.drum.enum import (
FLASK_EXT_FILE_NAME,
Expand Down Expand Up @@ -292,18 +294,122 @@ def handle_exception(e):

return []

def get_gunicorn_config(self):
    """Collect validated gunicorn settings from ``DRUM_GUNICORN_*`` runtime parameters.

    Returns
    -------
    dict
        Maps gunicorn setting names (``worker_class``, ``backlog``, ...) to
        values. A setting is included only when its runtime parameter is
        defined and passes validation; rejected values are logged and
        skipped.
    """
    # (runtime parameter, gunicorn setting, allowed lower-case values)
    choice_settings = (
        ("DRUM_GUNICORN_WORKER_CLASS", "worker_class", {"sync", "gevent"}),
        (
            "DRUM_GUNICORN_LOG_LEVEL",
            "loglevel",
            {"debug", "info", "warning", "error", "critical"},
        ),
    )
    # (runtime parameter, gunicorn setting, inclusive minimum, inclusive maximum)
    int_settings = (
        ("DRUM_GUNICORN_WORKER_CONNECTIONS", "worker_connections", 1, 10000),
        ("DRUM_GUNICORN_BACKLOG", "backlog", 1, 2048),
        ("DRUM_GUNICORN_TIMEOUT", "timeout", 1, 3600),
        ("DRUM_GUNICORN_GRACEFUL_TIMEOUT", "graceful_timeout", 1, 3600),
        ("DRUM_GUNICORN_KEEP_ALIVE", "keepalive", 1, 3600),
        ("DRUM_GUNICORN_MAX_REQUESTS", "max_requests", 100, 10000),
        ("DRUM_GUNICORN_MAX_REQUESTS_JITTER", "max_requests_jitter", 1, 10000),
    )

    config = {}

    for param, setting, allowed in choice_settings:
        if not RuntimeParameters.has(param):
            continue
        # Store the normalised value: the membership check is
        # case-insensitive, so keeping the original spelling would let
        # e.g. "Gevent" through validation in a form gunicorn's worker
        # alias lookup does not recognise.
        value = str(RuntimeParameters.get(param)).lower()
        if value in allowed:
            config[setting] = value
        else:
            logger.warning(
                "Ignoring %s=%r: expected one of %s", param, value, sorted(allowed)
            )

    for param, setting, minimum, maximum in int_settings:
        if not RuntimeParameters.has(param):
            continue
        raw = RuntimeParameters.get(param)
        try:
            value = int(raw)
        except (TypeError, ValueError):
            # A malformed number is treated like any other invalid value
            # instead of aborting server start-up.
            logger.warning("Ignoring %s=%r: not an integer", param, raw)
            continue
        if minimum <= value <= maximum:
            config[setting] = value
        else:
            logger.warning(
                "Ignoring %s=%r: expected a value between %s and %s",
                param,
                value,
                minimum,
                maximum,
            )

    return config

def get_server_type(self):
    """Return the web server backend selected by ``DRUM_SERVER_TYPE``.

    Returns
    -------
    str
        ``"flask"`` or ``"gunicorn"``. ``"flask"`` is returned when the
        runtime parameter is absent or holds an unsupported value, so the
        result is always one of the two supported names.
    """
    default_type = "flask"
    if not RuntimeParameters.has("DRUM_SERVER_TYPE"):
        return default_type

    requested = str(RuntimeParameters.get("DRUM_SERVER_TYPE")).strip().lower()
    if requested in {"flask", "gunicorn"}:
        return requested

    logger.warning(
        "Unsupported DRUM_SERVER_TYPE %r; falling back to %s", requested, default_type
    )
    return default_type


def _run_flask_app(self, app):
    """Serve ``app`` with the backend chosen by ``get_server_type()``.

    Host, port and process count are read from ``self._params``. With
    server type ``"gunicorn"`` the app is served by an embedded gunicorn
    application; otherwise Flask's built-in server is used.

    Raises
    ------
    DrumCommonException
        If the server fails with an ``OSError``, or if gunicorn is
        requested but not installed.
    """
    host = self._params.get("host", None)
    port = self._params.get("port", None)

    server_type = self.get_server_type()
    processes = 1
    if self._params.get("processes"):
        processes = self._params.get("processes")
        logger.info("Number of webserver processes: %s", processes)
    try:
        logger.info(self._params)
        if server_type == "gunicorn":
            logger.info("Starting gunicorn server")
            self._run_gunicorn_app(app, host, port, processes)
        else:
            app.run(host, port, threaded=False, processes=processes)
    except OSError as e:
        raise DrumCommonException(f"{e}: host: {host}; port: {port}") from e

def _run_gunicorn_app(self, app, host, port, default_workers):
    """Run ``app`` inside an embedded gunicorn application; blocks while serving.

    ``default_workers`` is used when neither ``gunicorn_workers`` nor
    ``max_workers`` is set in ``self._params``.
    """
    try:
        from gunicorn.app.base import BaseApplication
    except ImportError as e:
        raise DrumCommonException(
            "gunicorn is not installed. Please install gunicorn."
        ) from e

    params = self._params
    gunicorn_config = self.get_gunicorn_config()

    # Fallbacks applied when a setting is absent from the validated config.
    fallback_settings = {
        "worker_class": "sync",
        "backlog": 2048,
        "timeout": 120,
        "graceful_timeout": 30,
        "keepalive": 5,
        "max_requests": 2000,
        "max_requests_jitter": 500,
        "loglevel": "info",
    }

    class GunicornApp(BaseApplication):
        """Gunicorn application wrapper that serves the already-built WSGI app."""

        def load_config(self):
            self.cfg.set("bind", f"{host}:{port}")
            # Always resolve to a concrete worker count; gunicorn rejects None.
            workers = (
                params.get("gunicorn_workers")
                or params.get("max_workers")
                or default_workers
            )
            self.cfg.set("workers", workers)

            for name, fallback in fallback_settings.items():
                self.cfg.set(name, gunicorn_config.get(name, fallback))

            if gunicorn_config.get("worker_connections"):
                self.cfg.set("worker_connections", gunicorn_config.get("worker_connections"))

            # Access and error logs both go to "-" with a fixed format.
            self.cfg.set("accesslog", "-")
            self.cfg.set("errorlog", "-")
            self.cfg.set(
                "access_log_format",
                '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"',
            )

        def load(self):
            return app

    GunicornApp().run()

def terminate(self):
terminate_op = getattr(self._predictor, "terminate", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"processes": {{ processes }},
"monitor": "{{ monitor }}",
"monitor_embedded": "{{ monitor_embedded }}",
"server_type": "{{ server_type }}",
"model_id": "{{ model_id }}",
"deployment_id": {{ deployment_id | jsonify }},
"monitor_settings": {{ monitor_settings | jsonify }},
Expand All @@ -33,7 +34,16 @@
"sidecar": {{ sidecar | jsonify }},
"triton_host": {{ triton_host | jsonify }},
"triton_http_port": {{ triton_http_port | jsonify }},
"triton_grpc_port": {{ triton_grpc_port | jsonify }}
"triton_grpc_port": {{ triton_grpc_port | jsonify }},
"gunicorn_backlog": {{ gunicorn_backlog | jsonify }},
"gunicorn_timeout": {{ gunicorn_timeout | jsonify }},
"gunicorn_graceful_timeout": {{ gunicorn_graceful_timeout | jsonify }},
"gunicorn_keep_alive": {{ gunicorn_keep_alive | jsonify }},
"gunicorn_max_requests": {{ gunicorn_max_requests | jsonify }},
"gunicorn_max_requests_jitter": {{ gunicorn_max_requests_jitter | jsonify }},
"gunicorn_log_level": {{ gunicorn_log_level | jsonify }},
"gunicorn_workers": {{ gunicorn_workers | jsonify }},
"gunicorn_worker_class": {{ gunicorn_worker_class | jsonify }}
}
}
]
Expand Down
2 changes: 2 additions & 0 deletions custom_model_runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ argcomplete
trafaret>=2.0.0
docker>=4.2.2
flask
gevent>=22.10.2
gunicorn>=20.1.0
jinja2>=3.0.0
memory_profiler<1.0.0
numpy
Expand Down