Skip to content

Commit

Permalink
fix formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
gwang111 committed Nov 15, 2024
1 parent 3b147cd commit 65cb5b3
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 134 deletions.
8 changes: 6 additions & 2 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,13 +1578,17 @@ def _add_model_access_configs_to_model_data_sources(
),
)
)
mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is applied
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is applied
mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
)
acked_model_data_sources.append(mutable_model_data_source)
else:
mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is not applicable
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is not applicable
acked_model_data_sources.append(mutable_model_data_source)
return acked_model_data_sources

Expand Down

This file was deleted.

47 changes: 27 additions & 20 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,9 +801,14 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
js_class_deploy = JumpStartModel.deploy
js_class_deploy_args = set(signature(js_class_deploy).parameters.keys())

assert js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time == set()
assert (parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time ==
deploy_args_to_skip)
assert (
js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time
== set()
)
assert (
parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time
== deploy_args_to_skip
)

@mock.patch(
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
Expand Down Expand Up @@ -1775,18 +1780,17 @@ def test_model_set_deployment_config(
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
self,
mock_model_deploy: mock.Mock,
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_get_manifest: mock.Mock,
mock_get_jumpstart_configs: mock.Mock,
self,
mock_model_deploy: mock.Mock,
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_get_manifest: mock.Mock,
mock_get_jumpstart_configs: mock.Mock,
):
# WHERE
mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
mock_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs:
get_prototype_manifest(region, model_type)
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
mock_model_deploy.return_value = default_predictor

Expand All @@ -1799,7 +1803,11 @@ def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
assert model.config_name is None

# WHEN
model.deploy(model_access_configs={"pytorch-eqa-bert-base-cased":ModelAccessConfig(accept_eula=True)})
model.deploy(
model_access_configs={
"pytorch-eqa-bert-base-cased": ModelAccessConfig(accept_eula=True)
}
)

# THEN
mock_model_deploy.assert_called_once_with(
Expand All @@ -1822,18 +1830,17 @@ def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs(
self,
mock_model_deploy: mock.Mock,
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_get_manifest: mock.Mock,
mock_get_jumpstart_configs: mock.Mock,
self,
mock_model_deploy: mock.Mock,
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_get_manifest: mock.Mock,
mock_get_jumpstart_configs: mock.Mock,
):
# WHERE
mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
mock_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs:
get_prototype_manifest(region, model_type)
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
mock_model_deploy.return_value = default_predictor

Expand Down
123 changes: 66 additions & 57 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2175,48 +2175,46 @@ class TestAcceptEulaModelAccessConfig(TestCase):
MOCK_PUBLIC_MODEL_ID = "mock_public_model_id"
MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
{
'ChannelName': 'draft_model',
'S3DataSource': {
'CompressionType': 'None',
'S3DataType': 'S3Prefix',
'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/'
"ChannelName": "draft_model",
"S3DataSource": {
"CompressionType": "None",
"S3DataType": "S3Prefix",
"S3Uri": "s3://jumpstart_bucket/path/to/public/resources/",
},
'HostingEulaKey': None
"HostingEulaKey": None,
}
]
MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
{
'ChannelName': 'draft_model',
'S3DataSource': {
'CompressionType': 'None',
'S3DataType': 'S3Prefix',
'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/'
}
"ChannelName": "draft_model",
"S3DataSource": {
"CompressionType": "None",
"S3DataType": "S3Prefix",
"S3Uri": "s3://jumpstart_bucket/path/to/public/resources/",
},
}
]
MOCK_GATED_MODEL_ID = "mock_gated_model_id"
MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
{
'ChannelName': 'draft_model',
'S3DataSource': {
'CompressionType': 'None',
'S3DataType': 'S3Prefix',
'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/'
"ChannelName": "draft_model",
"S3DataSource": {
"CompressionType": "None",
"S3DataType": "S3Prefix",
"S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/",
},
'HostingEulaKey': "fmhMetadata/eula/llama3_2Eula.txt"
"HostingEulaKey": "fmhMetadata/eula/llama3_2Eula.txt",
}
]
MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
{
'ChannelName': 'draft_model',
'S3DataSource': {
'CompressionType': 'None',
'S3DataType': 'S3Prefix',
'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/',
'ModelAccessConfig': {
"AcceptEula": True
}
}
"ChannelName": "draft_model",
"S3DataSource": {
"CompressionType": "None",
"S3DataType": "S3Prefix",
"S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/",
"ModelAccessConfig": {"AcceptEula": True},
},
}
]

Expand All @@ -2232,14 +2230,17 @@ def test_public_additional_model_data_source_should_pass_through(self):
)

# THEN
assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
assert (
additional_model_data_sources
== self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
)

def test_multiple_public_additional_model_data_source_should_pass_through_both(self):
# WHERE / WHEN
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
model_data_sources=(
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+ self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
),
model_access_configs=None,
model_id=self.MOCK_PUBLIC_MODEL_ID,
Expand All @@ -2248,23 +2249,24 @@ def test_multiple_public_additional_model_data_source_should_pass_through_both(s

# THEN
assert additional_model_data_sources == (
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+ self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
)

def test_public_additional_model_data_source_with_model_access_config_should_ignored_it(self):
# WHERE / WHEN
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
model_access_configs={
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
},
model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
model_id=self.MOCK_GATED_MODEL_ID,
region=JUMPSTART_DEFAULT_REGION_NAME,
)

# THEN
assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
assert (
additional_model_data_sources
== self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
)

def test_no_additional_model_data_source_should_pass_through(self):
# WHERE / WHEN
Expand All @@ -2284,62 +2286,65 @@ def test_gated_additional_model_data_source_should_accept_it(self):
# WHERE / WHEN
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
model_access_configs={
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
},
model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
model_id=self.MOCK_GATED_MODEL_ID,
region=JUMPSTART_DEFAULT_REGION_NAME,
)

# THEN
assert additional_model_data_sources == self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
assert (
additional_model_data_sources
== self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
)

def test_multiple_gated_additional_model_data_source_should_accept_both(self):
# WHERE / WHEN
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
model_data_sources=(
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
),
model_access_configs={
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True),
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True),
self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True),
},
model_id=self.MOCK_GATED_MODEL_ID,
region=JUMPSTART_DEFAULT_REGION_NAME,
)

# THEN
assert additional_model_data_sources == (
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
)

# Mixed Positive Cases

def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(self):
def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(
self,
):
# WHERE / WHEN
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
model_data_sources=(
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
),
model_access_configs={
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
},
model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
model_id=self.MOCK_GATED_MODEL_ID,
region=JUMPSTART_DEFAULT_REGION_NAME,
)

# THEN
assert additional_model_data_sources == (
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
)

# Test Gated Negative Tests

def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error(self):
def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error(
self,
):
# WHERE / WHEN / THEN
with self.assertRaises(ValueError):
utils._add_model_access_configs_to_model_data_sources(
Expand All @@ -2354,33 +2359,37 @@ def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error
with self.assertRaises(ValueError):
utils._add_model_access_configs_to_model_data_sources(
model_data_sources=(
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
),
model_access_configs=None,
model_id=self.MOCK_GATED_MODEL_ID,
region=JUMPSTART_DEFAULT_REGION_NAME,
)

def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error(self):
def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error(
self,
):
# WHERE / WHEN / THEN
with self.assertRaises(ValueError):
utils._add_model_access_configs_to_model_data_sources(
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
model_access_configs={
self.MOCK_PUBLIC_MODEL_ID:ModelAccessConfig(accept_eula=True)
self.MOCK_PUBLIC_MODEL_ID: ModelAccessConfig(accept_eula=True)
},
model_id=self.MOCK_GATED_MODEL_ID,
region=JUMPSTART_DEFAULT_REGION_NAME,
)

def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error(self):
def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error(
self,
):
# WHERE / WHEN / THEN
with self.assertRaises(ValueError):
utils._add_model_access_configs_to_model_data_sources(
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
model_access_configs={
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=False)
self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False)
},
model_id=self.MOCK_GATED_MODEL_ID,
region=JUMPSTART_DEFAULT_REGION_NAME,
Expand Down
Loading

0 comments on commit 65cb5b3

Please sign in to comment.