Skip to content

Commit

Permalink
fix: pass name from modelbuilder constructor to created model
Browse files (browse the repository at this point in the history)
  • Loading branch information
gwang111 committed Sep 3, 2024
1 parent 7aa39f9 commit d8952e1
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/sagemaker/serve/builder/djl_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self):
self.nb_instance_type = None
self.ram_usage_model_load = None
self.role_arn = None
self.name = None

@abstractmethod
def _prepare_for_mode(self):
Expand Down Expand Up @@ -130,6 +131,7 @@ def _create_djl_model(self) -> Type[Model]:
huggingface_hub_token=self.env_vars.get("HF_TOKEN"),
image_config=self.image_config,
vpc_config=self.vpc_config,
name=self.name
)

if not self.image_uri:
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(self):
self.is_compiled = False
self.is_quantized = False
self.speculative_decoding_draft_model_source = None
self.name = None

@abstractmethod
def _prepare_for_mode(self, **kwargs):
Expand All @@ -147,7 +148,10 @@ def _is_jumpstart_model_id(self) -> bool:
def _create_pre_trained_js_model(self) -> Type[Model]:
"""Placeholder docstring"""
pysdk_model = JumpStartModel(
self.model, vpc_config=self.vpc_config, sagemaker_session=self.sagemaker_session
self.model,
name=self.name,
vpc_config=self.vpc_config,
sagemaker_session=self.sagemaker_session
)

self._original_deploy = pysdk_model.deploy
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def _create_model(self):
env=self.env_vars,
sagemaker_session=self.sagemaker_session,
predictor_cls=self._get_predictor,
name=self.name,
)

# store the modes in the model so that we may
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/serve/builder/tei_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self):
self.ram_usage_model_load = None
self.secret_key = None
self.role_arn = None
self.name = None

@abstractmethod
def _prepare_for_mode(self, *args, **kwargs):
Expand Down Expand Up @@ -105,6 +106,7 @@ def _create_tei_model(self, **kwargs) -> Type[Model]:
env=self.env_vars,
role=self.role_arn,
sagemaker_session=self.sagemaker_session,
name=self.name
)

logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/serve/builder/tf_serving_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self):
self.pysdk_model = None
self.schema_builder = None
self.env_vars = None
self.name = None

@abstractmethod
def _prepare_for_mode(self):
Expand Down Expand Up @@ -97,6 +98,7 @@ def _create_tensorflow_model(self):
env=self.env_vars,
sagemaker_session=self.sagemaker_session,
predictor_cls=self._get_tensorflow_predictor,
name=self.name
)

self.pysdk_model.mode = self.mode
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/serve/builder/tgi_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self):
self.ram_usage_model_load = None
self.secret_key = None
self.role_arn = None
self.name = None

@abstractmethod
def _prepare_for_mode(self, *args, **kwargs):
Expand Down Expand Up @@ -142,6 +143,7 @@ def _create_tgi_model(self) -> Type[Model]:
env=self.env_vars,
role=self.role_arn,
sagemaker_session=self.sagemaker_session,
name=self.name
)

self._original_deploy = pysdk_model.deploy
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/serve/builder/transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self):
self.schema_builder = None
self.inference_spec = None
self.shared_libs = None
self.name = None

@abstractmethod
def _prepare_for_mode(self, *args, **kwargs):
Expand All @@ -105,6 +106,7 @@ def _create_transformers_model(self) -> Type[Model]:
env=self.env_vars,
role=self.role_arn,
sagemaker_session=self.sagemaker_session,
name=self.name
)

logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
Expand Down
24 changes: 16 additions & 8 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
)

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_image_uri
and image_config == MOCK_IMAGE_CONFIG
Expand All @@ -326,6 +326,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
and role == mock_role_arn
and env == ENV_VARS
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down Expand Up @@ -425,13 +426,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
)

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_1p_dlc_image_uri
and model_data == model_data
and role == mock_role_arn
and env == ENV_VARS
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down Expand Up @@ -532,13 +534,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
)

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_image_uri
and model_data == model_data
and role == mock_role_arn
and env == ENV_VARS_INF_SPEC
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down Expand Up @@ -633,13 +636,14 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
)

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_image_uri
and model_data == model_data
and role == mock_role_arn
and env == ENV_VARS
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down Expand Up @@ -742,13 +746,14 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
)

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_image_uri
and model_data == model_data
and role == mock_role_arn
and env == ENV_VARS
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down Expand Up @@ -847,13 +852,14 @@ def test_build_happy_path_with_local_container_mode(
mock_mode.prepare.side_effect = lambda: None

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_image_uri
and model_data is None
and role == mock_role_arn
and env == {}
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down Expand Up @@ -968,13 +974,14 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
)

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_image_uri
and model_data is None
and role == mock_role_arn
and env == {}
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down Expand Up @@ -1119,13 +1126,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
)

mock_model_obj = Mock()
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
mock_model_obj
if image_uri == mock_image_uri
and model_data == model_data
and role == mock_role_arn
and env == ENV_VARS
and sagemaker_session == mock_session
and "model-name-" in name
else None
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def setUp(self):
self.instance.image_config = {}
self.instance.vpc_config = {}
self.instance.modes = {}
self.instance.name = "model-name-mock-uuid-hex"

@patch("os.makedirs")
@patch("os.path.exists")
Expand Down Expand Up @@ -71,5 +72,6 @@ def test_create_tensorflow_model(self, mock_model):
env=self.instance.env_vars,
sagemaker_session=self.instance.sagemaker_session,
predictor_cls=self.instance._get_tensorflow_predictor,
name="model-name-mock-uuid-hex"
)
self.assertEqual(model, mock_model.return_value)

0 comments on commit d8952e1

Please sign in to comment.