Skip to content

Commit

Permalink
Add environment variable VMARGS (#118)
Browse files Browse the repository at this point in the history
* Add environment variable with VMARGS

* Fix linting error

* Fix flake8

* Fix mock calls in default handler test

* Remove duplicate assert

* Fix function call

* Use pytest instead of py.test

* Change env variable name
  • Loading branch information
nikhil-sk authored Jan 31, 2023
1 parent 3774c1a commit 3821f30
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/sagemaker_inference/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DEFAULT_MODEL_SERVER_TIMEOUT = "60"
DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes
DEFAULT_HTTP_PORT = "8080"
DEFAULT_VMARGS = "-XX:-UseContainerSupport"

SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str

Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self):
self._inference_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT)
self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT)
self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV)
self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS)

@staticmethod
def _parse_module_name(program_param):
Expand Down Expand Up @@ -140,3 +142,8 @@ def safe_port_range(self): # type: () -> str
specified by SageMaker for handling pings and invocations.
"""
return self._safe_port_range

@property
def vmargs(self): # type: () -> str
"""str: vmargs can be provided for the JVM, to be overriden"""
return self._vmargs
2 changes: 1 addition & 1 deletion src/sagemaker_inference/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _generate_mms_config_properties(env, handler_service=None):
"default_workers_per_model": env.model_server_workers,
"inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
"management_address": "http://0.0.0.0:{}".format(env.management_http_port),
"vmargs": "-XX:-UseContainerSupport",
"vmargs": env.vmargs,
}
# If provided, add handler service to user config
if handler_service:
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker_inference/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" # type: str
MODEL_SERVER_WORKERS_ENV = "SAGEMAKER_MODEL_SERVER_WORKERS" # type: str
MODEL_SERVER_TIMEOUT_ENV = "SAGEMAKER_MODEL_SERVER_TIMEOUT" # type: str
MODEL_SERVER_VMARGS = "SAGEMAKER_MODEL_SERVER_VMARGS" # type: str
STARTUP_TIMEOUT_ENV = "SAGEMAKER_STARTUP_TIMEOUT" # type: str
BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str
SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_default_handler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_handle():
result = handler_service.handle(DATA, CONTEXT)

assert result == TRANSFORMED_RESULT
assert transformer.transform.called_once_with(DATA, CONTEXT)
transformer.transform.assert_called_once_with(DATA, CONTEXT)


def test_initialize():
Expand All @@ -57,4 +57,4 @@ def getitem(key):
context.system_properties.__getitem__.side_effect = getitem
DefaultHandlerService(transformer).initialize(context)

assert transformer.validate_and_initialize().called_once()
transformer.validate_and_initialize.assert_called_once()
2 changes: 2 additions & 0 deletions test/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html",
parameters.BIND_TO_PORT_ENV: "1738",
parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport",
},
clear=True,
)
Expand All @@ -45,6 +46,7 @@ def test_env():
assert env.inference_http_port == "1738"
assert env.management_http_port == "1738"
assert env.safe_port_range == "1111-2222"
assert "-XX:-UseContainerSupport" in env.vmargs


@pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ passenv =
# {posargs} can be passed in by additional arguments specified when invoking tox.
# Can be used to specify which tests to run, e.g.: tox -- -s
commands =
coverage run --rcfile .coveragerc_{envname} --source sagemaker_inference -m py.test {posargs}
coverage run --rcfile .coveragerc_{envname} --source sagemaker_inference -m pytest {posargs}
{env:IGNORE_COVERAGE:} coverage report --rcfile .coveragerc_{envname}
{env:IGNORE_COVERAGE:} coverage html --rcfile .coveragerc_{envname}

Expand Down

0 comments on commit 3821f30

Please sign in to comment.