Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimization technique related validations. #4921

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7ec16e6
Enable quantization and compilation in the same optimization job via …
Sep 18, 2024
cf70f59
Require EULA acceptance when using a gated 1p draft model via ModelBu…
Nov 8, 2024
fcb5092
add accept_draft_model_eula to JumpStartModel when deployment config …
Nov 8, 2024
9489b8d
add map of valid optimization combinations
Nov 8, 2024
5512c26
Add ModelBuilder support for JumpStart-provided draft models.
Nov 9, 2024
c94a78b
Tweak draft model EULA validations and messaging. Remove redundant de…
Nov 9, 2024
d10c475
Add "Auto" speculative decoding ModelProvider option; add validations…
Nov 11, 2024
8fb27a0
Fix JumpStartModel.AdditionalModelDataSource model access config assi…
Nov 12, 2024
779f6d6
move the accept eula configurations into deploy flow
gwang111 Nov 12, 2024
aef3a90
Merge branch 'master' into QuicksilverV2
gwang111 Nov 12, 2024
b7b15b8
move the accept eula configurations into deploy flow
gwang111 Nov 12, 2024
748ea4b
Use correct bucket for SM/JS draft models and minor formatting/valida…
Nov 13, 2024
a7feb54
Remove obsolete docstring.
Nov 13, 2024
694b4f2
remove references to accept_draft_model_eula
gwang111 Nov 13, 2024
7b6aef1
renaming of eula fn and error msg
gwang111 Nov 13, 2024
ce47be5
Merge branch 'master' into QuicksilverV2
gwang111 Nov 13, 2024
1f75072
fix: pin testing deps (#4925)
benieric Nov 13, 2024
277e0b1
Revert "change: add TGI 2.4.0 image uri (#4922)" (#4926)
Captainia Nov 13, 2024
8f0083b
fix naming and messaging
gwang111 Nov 14, 2024
8b73f34
ModelBuilder speculative decoding UTs and minor fixes.
Nov 14, 2024
c06aef0
Merge branch 'master' into QuicksilverV2
gwang111 Nov 14, 2024
09a54dc
Fix set union.
Nov 14, 2024
3b147cd
add UTs for JumpStart deployment
gwang111 Nov 15, 2024
65cb5b3
fix formatting issues
gwang111 Nov 15, 2024
4d1e12b
address validation comments
gwang111 Nov 15, 2024
bf706ad
fix doc strings
gwang111 Nov 15, 2024
f121eb0
Add TRTLLM compilation + speculative decoding validation.
Nov 15, 2024
9148e70
address nits
gwang111 Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


from typing import Any, Dict, List, Optional, Union
from sagemaker_core.shapes import ModelAccessConfig
from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
Expand Down Expand Up @@ -53,11 +54,11 @@
add_hub_content_arn_tags,
add_jumpstart_model_info_tags,
get_default_jumpstart_session_with_user_agent_suffix,
get_neo_content_bucket,
get_top_ranked_config_name,
update_dict_if_key_not_present,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
get_draft_model_content_bucket,
)

from sagemaker.jumpstart.factory.utils import (
Expand All @@ -70,7 +71,12 @@

from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.session import Session
from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags
from sagemaker.utils import (
camel_case_to_pascal_case,
name_from_base,
format_tags,
Tags,
)
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker import resource_requirements
Expand Down Expand Up @@ -565,7 +571,9 @@ def _add_additional_model_data_sources_to_kwargs(
# Append speculative decoding data source from metadata
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
for data_source in speculative_decoding_data_sources:
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
data_source.s3_data_source.set_bucket(
get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region)
)
api_shape_additional_model_data_sources = (
[
camel_case_to_pascal_case(data_source.to_json())
Expand Down Expand Up @@ -648,6 +656,7 @@ def get_deploy_kwargs(
training_config_name: Optional[str] = None,
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -684,6 +693,7 @@ def get_deploy_kwargs(
resources=resources,
config_name=config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
)
deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
deploy_kwargs.specs = verify_model_region_and_return_specs(
Expand Down
41 changes: 31 additions & 10 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
from botocore.exceptions import ClientError

from sagemaker_core.shapes import ModelAccessConfig
from sagemaker import payloads
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
Expand Down Expand Up @@ -51,6 +52,7 @@
add_instance_rate_stats_to_benchmark_metrics,
deployment_config_response_data,
_deployment_config_lru_cache,
_add_model_access_configs_to_model_data_sources,
)
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
Expand Down Expand Up @@ -540,12 +542,16 @@ def attach(
inferred_model_id = inferred_model_version = inferred_inference_component_name = None

if inference_component_name is None or model_id is None or model_version is None:
inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = (
get_model_info_from_endpoint(
endpoint_name=endpoint_name,
inference_component_name=inference_component_name,
sagemaker_session=sagemaker_session,
)
(
inferred_model_id,
inferred_model_version,
inferred_inference_component_name,
_,
_,
) = get_model_info_from_endpoint(
endpoint_name=endpoint_name,
inference_component_name=inference_component_name,
sagemaker_session=sagemaker_session,
)

model_id = model_id or inferred_model_id
Expand Down Expand Up @@ -659,6 +665,7 @@ def deploy(
managed_instance_scaling: Optional[str] = None,
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
) -> PredictorBase:
"""Creates endpoint by calling base ``Model`` class `deploy` method.

Expand Down Expand Up @@ -755,6 +762,11 @@ def deploy(
(Default: EndpointType.MODEL_BASED).
routing_config (Optional[Dict]): Settings the control how the endpoint routes
incoming traffic to the instances that the endpoint hosts.
model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require
gwang111 marked this conversation as resolved.
Show resolved Hide resolved
ModelAccessConfig, provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }`
to indicate whether model terms of use have been accepted. The `accept_eula` value
must be explicitly defined as `True` in order to accept the end-user license
agreement (EULA) that some models require. (Default: None)

Raises:
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
Expand Down Expand Up @@ -795,6 +807,7 @@ def deploy(
model_type=self.model_type,
config_name=self.config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
)
if (
self.model_type == JumpStartModelType.PROPRIETARY
Expand All @@ -804,6 +817,13 @@ def deploy(
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
)

self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
self.additional_model_data_sources,
deploy_kwargs.model_access_configs,
deploy_kwargs.model_id,
deploy_kwargs.region,
)

try:
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
except ClientError as e:
Expand Down Expand Up @@ -1016,10 +1036,11 @@ def _get_deployment_configs(
)

if metadata_config.benchmark_metrics:
err, metadata_config.benchmark_metrics = (
add_instance_rate_stats_to_benchmark_metrics(
self.region, metadata_config.benchmark_metrics
)
(
err,
metadata_config.benchmark_metrics,
) = add_instance_rate_stats_to_benchmark_metrics(
self.region, metadata_config.benchmark_metrics
)

config_components = metadata_config.config_components.get(config_name)
Expand Down
15 changes: 12 additions & 3 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
from sagemaker.utils import (
S3_PREFIX,
Expand Down Expand Up @@ -1081,9 +1082,9 @@ def set_bucket(self, bucket: str) -> None:
class AdditionalModelDataSource(JumpStartDataHolderType):
"""Data class of additional model data source mirrors CreateModel API."""

SERIALIZATION_EXCLUSION_SET: Set[str] = set()
SERIALIZATION_EXCLUSION_SET = {"provider"}

__slots__ = ["channel_name", "s3_data_source"]
__slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a AdditionalModelDataSource object.
Expand All @@ -1101,6 +1102,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
"""
self.channel_name: str = json_obj["channel_name"]
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
self.hosting_eula_key: str = json_obj.get("hosting_eula_key")
self.provider: Dict = json_obj.get("provider", {})

def to_json(self, exclude_keys=True) -> Dict[str, Any]:
"""Returns json representation of AdditionalModelDataSource object."""
Expand All @@ -1119,7 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]:
class JumpStartModelDataSource(AdditionalModelDataSource):
"""Data class JumpStart additional model data source."""

SERIALIZATION_EXCLUSION_SET = {"artifact_version"}
SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union(
{"artifact_version"}
)

__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__

Expand Down Expand Up @@ -2239,6 +2244,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"config_name",
"routing_config",
"specs",
"model_access_configs",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -2252,6 +2258,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"sagemaker_session",
"training_instance_type",
"config_name",
"model_access_configs",
}

def __init__(
Expand Down Expand Up @@ -2290,6 +2297,7 @@ def __init__(
endpoint_type: Optional[EndpointType] = None,
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
) -> None:
"""Instantiates JumpStartModelDeployKwargs object."""

Expand Down Expand Up @@ -2327,6 +2335,7 @@ def __init__(
self.endpoint_type = endpoint_type
self.config_name = config_name
self.routing_config = routing_config
self.model_access_configs = model_access_configs


class JumpStartEstimatorInitKwargs(JumpStartKwargs):
Expand Down
93 changes: 91 additions & 2 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
"""This module contains utilities related to SageMaker JumpStart."""
from __future__ import absolute_import

from copy import copy
import logging
import os
Expand All @@ -22,6 +23,7 @@
from botocore.exceptions import ClientError
from packaging.version import Version
import botocore
from sagemaker_core.shapes import ModelAccessConfig
import sagemaker
from sagemaker.config.config_schema import (
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
Expand Down Expand Up @@ -55,6 +57,7 @@
TagsDict,
get_instance_rate_per_hour,
get_domain_for_region,
camel_case_to_pascal_case,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.user_agent import get_user_agent_extra_suffix
Expand Down Expand Up @@ -555,11 +558,18 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
"""Returns EULA message to display if one is available, else empty string."""
if model_specs.hosting_eula_key is None:
return ""
return get_formatted_eula_message_template(
model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key
)


def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str:
"""Returns a formatted EULA message."""
return (
f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
f"{get_domain_for_region(region)}"
f"/{model_specs.hosting_eula_key} for terms of use."
f"/{hosting_eula_key} for terms of use."
)


Expand Down Expand Up @@ -1525,3 +1535,82 @@ def wrapped_f(*args, **kwargs):
if _func is None:
return wrapper_cache
return wrapper_cache(_func)

gwang111 marked this conversation as resolved.
Show resolved Hide resolved

def _add_model_access_configs_to_model_data_sources(
model_data_sources: List[Dict[str, any]],
model_access_configs: Dict[str, ModelAccessConfig],
model_id: str,
region: str,
) -> List[Dict[str, any]]:
"""Iterate over the accept EULA configs to ensure all channels are matched

Args:
model_data_sources (DeploymentConfigMetadata): Model data sources that will be updated
model_access_configs (DeploymentConfigMetadata): Config holding accept_eula field
model_id (DeploymentConfigMetadata): Jumpstart model id.
region (str): Region where the user is operating in.
Returns:
List[Dict[str, Any]]: List of model data sources with accept EULA configs applied
Raise:
ValueError if at least one channel that requires EULA acceptance as not passed.
"""
if not model_data_sources:
return model_data_sources

acked_model_data_sources = []
for model_data_source in model_data_sources:
hosting_eula_key = model_data_source.get("HostingEulaKey")
mutable_model_data_source = model_data_source.copy()
if hosting_eula_key:
if (
not model_access_configs
or not model_access_configs.get(model_id)
or not model_access_configs.get(model_id).accept_eula
):
eula_message_template = (
"{model_source}{base_eula_message}{model_access_configs_message}"
)
model_access_config_entry = (
'"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id)
)
raise ValueError(
eula_message_template.format(
model_source="Additional " if model_data_source.get("ChannelName") else "",
base_eula_message=get_formatted_eula_message_template(
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
),
model_access_configs_message=(
"Please add a ModelAccessConfig entry:"
f" {model_access_config_entry} "
"to model_access_configs to accept the EULA."
),
)
)
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is applied
mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
)
acked_model_data_sources.append(mutable_model_data_source)
else:
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is not applicable
acked_model_data_sources.append(mutable_model_data_source)
return acked_model_data_sources


def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
"""Returns the correct content bucket for a 1p draft model."""
neo_bucket = get_neo_content_bucket(region=region)
if not provider:
return neo_bucket
provider_name = provider.get("name", "")
if provider_name == "JumpStart":
classification = provider.get("classification", "ungated")
if classification == "gated":
return get_jumpstart_gated_content_bucket(region=region)
return get_jumpstart_content_bucket(region=region)
return neo_bucket
Loading
Loading