diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index a5e9e1b6a4..82bc1fc174 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Union +from sagemaker_core.shapes import ModelAccessConfig from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -53,11 +54,11 @@ add_hub_content_arn_tags, add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, - get_neo_content_bucket, get_top_ranked_config_name, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, + get_draft_model_content_bucket, ) from sagemaker.jumpstart.factory.utils import ( @@ -70,7 +71,12 @@ from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from sagemaker.session import Session -from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags +from sagemaker.utils import ( + camel_case_to_pascal_case, + name_from_base, + format_tags, + Tags, +) from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker import resource_requirements @@ -565,7 +571,9 @@ def _add_additional_model_data_sources_to_kwargs( # Append speculative decoding data source from metadata speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources() for data_source in speculative_decoding_data_sources: - data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region)) + data_source.s3_data_source.set_bucket( + get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region) + ) api_shape_additional_model_data_sources = ( [ camel_case_to_pascal_case(data_source.to_json()) @@ -648,6 +656,7 @@ def get_deploy_kwargs( training_config_name: Optional[str] = None, config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -684,6 +693,7 @@ def get_deploy_kwargs( resources=resources, config_name=config_name, routing_config=routing_config, + model_access_configs=model_access_configs, ) deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs) deploy_kwargs.specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 486079718b..65bb156ee3 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -18,6 +18,7 @@ import pandas as pd from botocore.exceptions import ClientError +from sagemaker_core.shapes import ModelAccessConfig from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -51,6 +52,7 @@ add_instance_rate_stats_to_benchmark_metrics, deployment_config_response_data, _deployment_config_lru_cache, + _add_model_access_configs_to_model_data_sources, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from 
sagemaker.jumpstart.enums import JumpStartModelType
@@ -540,12 +542,16 @@ def attach(
         inferred_model_id = inferred_model_version = inferred_inference_component_name = None
 
         if inference_component_name is None or model_id is None or model_version is None:
-            inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = (
-                get_model_info_from_endpoint(
-                    endpoint_name=endpoint_name,
-                    inference_component_name=inference_component_name,
-                    sagemaker_session=sagemaker_session,
-                )
+            (
+                inferred_model_id,
+                inferred_model_version,
+                inferred_inference_component_name,
+                _,
+                _,
+            ) = get_model_info_from_endpoint(
+                endpoint_name=endpoint_name,
+                inference_component_name=inference_component_name,
+                sagemaker_session=sagemaker_session,
             )
 
         model_id = model_id or inferred_model_id
@@ -659,6 +665,7 @@ def deploy(
         managed_instance_scaling: Optional[str] = None,
         endpoint_type: EndpointType = EndpointType.MODEL_BASED,
         routing_config: Optional[Dict[str, Any]] = None,
+        model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
     ) -> PredictorBase:
         """Creates endpoint by calling base ``Model`` class `deploy` method.
@@ -755,6 +762,12 @@ def deploy(
                 (Default: EndpointType.MODEL_BASED).
             routing_config (Optional[Dict]): Settings that control how the endpoint routes
                 incoming traffic to the instances that the endpoint hosts.
+            model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that
+                require a ModelAccessConfig, provide a mapping like
+                `{"model_id": ModelAccessConfig(accept_eula=True)}` to indicate whether the
+                model's terms of use have been accepted. The `accept_eula` value must be
+                explicitly set to `True` in order to accept the end-user license agreement
+                (EULA) that some models require. (Default: None)
 
         Raises:
             MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
@@ -795,6 +807,7 @@ def deploy(
             model_type=self.model_type,
             config_name=self.config_name,
             routing_config=routing_config,
+            model_access_configs=model_access_configs,
         )
 
         if (
             self.model_type == JumpStartModelType.PROPRIETARY
@@ -804,6 +817,13 @@ def deploy(
                 f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
) + self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources( + self.additional_model_data_sources, + deploy_kwargs.model_access_configs, + deploy_kwargs.model_id, + deploy_kwargs.region, + ) + try: predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) except ClientError as e: @@ -1016,10 +1036,11 @@ def _get_deployment_configs( ) if metadata_config.benchmark_metrics: - err, metadata_config.benchmark_metrics = ( - add_instance_rate_stats_to_benchmark_metrics( - self.region, metadata_config.benchmark_metrics - ) + ( + err, + metadata_config.benchmark_metrics, + ) = add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics ) config_components = metadata_config.config_components.get(config_name) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e77c407372..cb989ca4d4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -17,6 +17,7 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import ( S3_PREFIX, @@ -1081,9 +1082,9 @@ def set_bucket(self, bucket: str) -> None: class AdditionalModelDataSource(JumpStartDataHolderType): """Data class of additional model data source mirrors CreateModel API.""" - SERIALIZATION_EXCLUSION_SET: Set[str] = set() + SERIALIZATION_EXCLUSION_SET = {"provider"} - __slots__ = ["channel_name", "s3_data_source"] + __slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"] def __init__(self, spec: Dict[str, Any]): """Initializes a AdditionalModelDataSource object. 
@@ -1101,6 +1102,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ self.channel_name: str = json_obj["channel_name"] self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"]) + self.hosting_eula_key: str = json_obj.get("hosting_eula_key") + self.provider: Dict = json_obj.get("provider", {}) def to_json(self, exclude_keys=True) -> Dict[str, Any]: """Returns json representation of AdditionalModelDataSource object.""" @@ -1119,7 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]: class JumpStartModelDataSource(AdditionalModelDataSource): """Data class JumpStart additional model data source.""" - SERIALIZATION_EXCLUSION_SET = {"artifact_version"} + SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union( + {"artifact_version"} + ) __slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__ @@ -2239,6 +2244,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "config_name", "routing_config", "specs", + "model_access_configs", ] SERIALIZATION_EXCLUSION_SET = { @@ -2252,6 +2258,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "sagemaker_session", "training_instance_type", "config_name", + "model_access_configs", } def __init__( @@ -2290,6 +2297,7 @@ def __init__( endpoint_type: Optional[EndpointType] = None, config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -2327,6 +2335,7 @@ def __init__( self.endpoint_type = endpoint_type self.config_name = config_name self.routing_config = routing_config + self.model_access_configs = model_access_configs class JumpStartEstimatorInitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index b33d6563e5..dfe3d7f1dd 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import + from copy import copy import logging import os @@ -22,6 +23,7 @@ from botocore.exceptions import ClientError from packaging.version import Version import botocore +from sagemaker_core.shapes import ModelAccessConfig import sagemaker from sagemaker.config.config_schema import ( MODEL_ENABLE_NETWORK_ISOLATION_PATH, @@ -55,6 +57,7 @@ TagsDict, get_instance_rate_per_hour, get_domain_for_region, + camel_case_to_pascal_case, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.user_agent import get_user_agent_extra_suffix @@ -555,11 +558,18 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: """Returns EULA message to display if one is available, else empty string.""" if model_specs.hosting_eula_key is None: return "" + return get_formatted_eula_message_template( + model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key + ) + + +def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str: + """Returns a formatted EULA message.""" return ( - f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). " + f"Model '{model_id}' requires accepting end-user license agreement (EULA). " f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." 
f"{get_domain_for_region(region)}" - f"/{model_specs.hosting_eula_key} for terms of use." + f"/{hosting_eula_key} for terms of use." ) @@ -1525,3 +1535,82 @@ def wrapped_f(*args, **kwargs): if _func is None: return wrapper_cache return wrapper_cache(_func) + + +def _add_model_access_configs_to_model_data_sources( + model_data_sources: List[Dict[str, any]], + model_access_configs: Dict[str, ModelAccessConfig], + model_id: str, + region: str, +) -> List[Dict[str, any]]: + """Iterate over the accept EULA configs to ensure all channels are matched + + Args: + model_data_sources (DeploymentConfigMetadata): Model data sources that will be updated + model_access_configs (DeploymentConfigMetadata): Config holding accept_eula field + model_id (DeploymentConfigMetadata): Jumpstart model id. + region (str): Region where the user is operating in. + Returns: + List[Dict[str, Any]]: List of model data sources with accept EULA configs applied + Raise: + ValueError if at least one channel that requires EULA acceptance as not passed. + """ + if not model_data_sources: + return model_data_sources + + acked_model_data_sources = [] + for model_data_source in model_data_sources: + hosting_eula_key = model_data_source.get("HostingEulaKey") + mutable_model_data_source = model_data_source.copy() + if hosting_eula_key: + if ( + not model_access_configs + or not model_access_configs.get(model_id) + or not model_access_configs.get(model_id).accept_eula + ): + eula_message_template = ( + "{model_source}{base_eula_message}{model_access_configs_message}" + ) + model_access_config_entry = ( + '"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id) + ) + raise ValueError( + eula_message_template.format( + model_source="Additional " if model_data_source.get("ChannelName") else "", + base_eula_message=get_formatted_eula_message_template( + model_id=model_id, region=region, hosting_eula_key=hosting_eula_key + ), + model_access_configs_message=( + "Please add a ModelAccessConfig entry:" + f" {model_access_config_entry} " + "to model_access_configs to accept the EULA." 
+ ), + ) + ) + mutable_model_data_source.pop( + "HostingEulaKey" + ) # pop when model access config is applied + mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( + camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump()) + ) + acked_model_data_sources.append(mutable_model_data_source) + else: + mutable_model_data_source.pop( + "HostingEulaKey" + ) # pop when model access config is not applicable + acked_model_data_sources.append(mutable_model_data_source) + return acked_model_data_sources + + +def get_draft_model_content_bucket(provider: Dict, region: str) -> str: + """Returns the correct content bucket for a 1p draft model.""" + neo_bucket = get_neo_content_bucket(region=region) + if not provider: + return neo_bucket + provider_name = provider.get("name", "") + if provider_name == "JumpStart": + classification = provider.get("classification", "ungated") + if classification == "gated": + return get_jumpstart_gated_content_bucket(region=region) + return get_jumpstart_content_bucket(region=region) + return neo_bucket diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index cfb43b813a..7d6a052023 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -48,6 +48,9 @@ _custom_speculative_decoding, SPECULATIVE_DRAFT_MODEL, _is_inferentia_or_trainium, + _jumpstart_speculative_decoding, + _deployment_config_contains_draft_model, + _is_draft_model_jumpstart_provided, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -718,24 +721,34 @@ def _optimize_for_jumpstart( f"Model '{self.model}' requires accepting end-user license agreement (EULA)." ) - is_compilation = (not quantization_config) and ( - (compilation_config is not None) or _is_inferentia_or_trainium(instance_type) + is_compilation = (compilation_config is not None) or _is_inferentia_or_trainium( + instance_type ) pysdk_model_env_vars = dict() if is_compilation: pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) - optimization_config, override_env = _extract_optimization_config_and_env( - quantization_config, compilation_config + # optimization_config can contain configs for both quantization and compilation + optimization_config, quantization_override_env, compilation_override_env = ( + _extract_optimization_config_and_env(quantization_config, compilation_config) ) - if not optimization_config and is_compilation: - override_env = override_env or pysdk_model_env_vars - optimization_config = { - "ModelCompilationConfig": { - "OverrideEnvironment": override_env, - } - } + + if not optimization_config: + optimization_config = {} + + if not optimization_config.get("ModelCompilationConfig") and is_compilation: + # Fallback to default if override_env is None or empty + if not compilation_override_env: + compilation_override_env = pysdk_model_env_vars + + # Update optimization_config with ModelCompilationConfig + override_compilation_config = ( + {"OverrideEnvironment": compilation_override_env} + if compilation_override_env + else {} + ) + optimization_config["ModelCompilationConfig"] = override_compilation_config if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) @@ -766,7 +779,7 @@ def _optimize_for_jumpstart( "OptimizationJobName": job_name, "ModelSource": model_source, "DeploymentInstanceType": self.instance_type, - "OptimizationConfigs": [optimization_config], + "OptimizationConfigs": [{k: v} for k, 
v in optimization_config.items()], "OutputConfig": output_config, "RoleArn": self.role_arn, } @@ -789,7 +802,13 @@ def _optimize_for_jumpstart( "AcceptEula": True } - optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env) + optimization_env_vars = _update_environment_variables( + optimization_env_vars, + { + **(quantization_override_env or {}), + **(compilation_override_env or {}), + }, + ) if optimization_env_vars: self.pysdk_model.env.update(optimization_env_vars) if quantization_config or is_compilation: @@ -813,9 +832,7 @@ def _is_gated_model(self, model=None) -> bool: return "private" in s3_uri def _set_additional_model_source( - self, - speculative_decoding_config: Optional[Dict[str, Any]] = None, - accept_eula: Optional[bool] = None, + self, speculative_decoding_config: Optional[Dict[str, Any]] = None ) -> None: """Set Additional Model Source to ``this`` model. @@ -825,9 +842,10 @@ def _set_additional_model_source( """ if speculative_decoding_config: model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) - if model_provider == "sagemaker": + if model_provider in ["sagemaker", "auto"]: additional_model_data_sources = ( self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( "AdditionalDataSources" @@ -840,24 +858,56 @@ def _set_additional_model_source( speculative_decoding_config ) if deployment_config: - self.pysdk_model.set_deployment_config( - config_name=deployment_config.get("DeploymentConfigName"), - instance_type=deployment_config.get("InstanceType"), - ) + if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided( + deployment_config + ): + raise ValueError( + "No `Sagemaker` provided draft model was found for " + f"{self.model}. Try setting `ModelProvider` " + "to `Auto` instead." + ) + + try: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) + except ValueError as e: + raise ValueError( + f"{e} If using speculative_decoding_config, " + "accept the EULA by setting `AcceptEula`=True." + ) else: raise ValueError( "Cannot find deployment config compatible for optimization job." ) + else: + if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided( + self.pysdk_model.deployment_config + ): + raise ValueError( + "No `Sagemaker` provided draft model was found for " + f"{self.model}. Try setting `ModelProvider` " + "to `Auto` instead." 
+ ) self.pysdk_model.env.update( - {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"} + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} ) self.pysdk_model.add_tags( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider}, + ) + elif model_provider == "jumpstart": + _jumpstart_speculative_decoding( + model=self.pysdk_model, + speculative_decoding_config=speculative_decoding_config, + sagemaker_session=self.sagemaker_session, ) else: self.pysdk_model = _custom_speculative_decoding( - self.pysdk_model, speculative_decoding_config, accept_eula + self.pysdk_model, + speculative_decoding_config, + speculative_decoding_config.get("AcceptEula", False), ) def _find_compatible_deployment_config( @@ -875,15 +925,17 @@ def _find_compatible_deployment_config( for deployment_config in self.pysdk_model.list_deployment_configs(): image_uri = deployment_config.get("deployment_config", {}).get("ImageUri") - if _is_image_compatible_with_optimization_job(image_uri): + if _is_image_compatible_with_optimization_job( + image_uri + ) and _deployment_config_contains_draft_model(deployment_config): if ( - model_provider == "sagemaker" + model_provider in ["sagemaker", "auto"] and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") ) or model_provider == "custom": return deployment_config # There's no matching config from jumpstart to add sagemaker draft model location - if model_provider == "sagemaker": + if model_provider in ["sagemaker", "auto"]: return None # fall back to the default jumpstart model deployment config for optimization job diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index d1f1ab6ba2..61af6953a2 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -76,6 +76,7 @@ _is_s3_uri, _custom_speculative_decoding, _extract_speculative_draft_model_provider, + _jumpstart_speculative_decoding, ) from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( @@ -590,7 +591,7 @@ def _model_builder_deploy_wrapper( ) if "endpoint_logging" not in kwargs: - kwargs["endpoint_logging"] = True + kwargs["endpoint_logging"] = False predictor = self._original_deploy( *args, instance_type=instance_type, @@ -1235,9 +1236,6 @@ def _model_builder_optimize_wrapper( if self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.") - if quantization_config and compilation_config: - raise ValueError("Quantization config and compilation config are mutually exclusive.") - self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() self.instance_type = instance_type or self.instance_type self.role_arn = role_arn or self.role_arn @@ -1279,6 +1277,36 @@ def _model_builder_optimize_wrapper( ) if input_args: + optimization_instance_type = input_args["DeploymentInstanceType"] + + # Compilation using TRTLLM and Llama-3.1 is currently not supported. 
+ # TRTLLM is used by Neo if the following are provided: + # 1) a GPU instance type + # 2) compilation config + gpu_instance_families = ["g5", "g6", "p4d", "p4de", "p5"] + is_gpu_instance = optimization_instance_type and any( + gpu_instance_family in optimization_instance_type + for gpu_instance_family in gpu_instance_families + ) + + # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B" + # JS Model ID format = "meta-textgeneration-llama-3-1-8b" + llama_3_1_keywords = ["llama-3.1", "llama-3-1"] + is_llama_3_1 = self.model and any( + keyword in self.model.lower() for keyword in llama_3_1_keywords + ) + + if is_gpu_instance and self.model and self.is_compiled: + if is_llama_3_1: + raise ValueError( + "Compilation is not supported for Llama-3.1 with a GPU instance." + ) + if speculative_decoding_config: + raise ValueError( + "Compilation is not supported with speculative decoding with " + "a GPU instance." + ) + self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) job_status = self.sagemaker_session.wait_for_optimization_job(job_name) return _generate_optimized_model(self.pysdk_model, job_status) @@ -1323,9 +1351,17 @@ def _optimize_for_hf( Returns: Optional[Dict[str, Any]]: Model optimization job input arguments. """ - self.pysdk_model = _custom_speculative_decoding( - self.pysdk_model, speculative_decoding_config, False - ) + if speculative_decoding_config: + if speculative_decoding_config.get("ModelProvider", "").lower() == "jumpstart": + _jumpstart_speculative_decoding( + model=self.pysdk_model, + speculative_decoding_config=speculative_decoding_config, + sagemaker_session=self.sagemaker_session, + ) + else: + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, False + ) if quantization_config or compilation_config: create_optimization_job_args = { @@ -1342,11 +1378,18 @@ def _optimize_for_hf( model_source = _generate_model_source(self.pysdk_model.model_data, False) create_optimization_job_args["ModelSource"] = model_source - optimization_config, override_env = _extract_optimization_config_and_env( - quantization_config, compilation_config + optimization_config, quantization_override_env, compilation_override_env = ( + _extract_optimization_config_and_env(quantization_config, compilation_config) + ) + create_optimization_job_args["OptimizationConfigs"] = [ + {k: v} for k, v in optimization_config.items() + ] + self.pysdk_model.env.update( + { + **(quantization_override_env or {}), + **(compilation_override_env or {}), + } ) - create_optimization_job_args["OptimizationConfigs"] = [optimization_config] - self.pysdk_model.env.update(override_env) output_config = {"S3OutputLocation": output_path} if kms_key: diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 5781c0bade..14df6b3639 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -17,8 +17,10 @@ import logging from typing import Dict, Any, Optional, Union, List, Tuple -from sagemaker import Model +from sagemaker import Model, Session from sagemaker.enums import Tag +from sagemaker.jumpstart.utils import accessors, get_eula_message + logger = logging.getLogger(__name__) @@ -58,6 +60,46 @@ def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri) +def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) -> bool: + """Checks 
whether a deployment config contains a speculative decoding draft model.
+
+    Args:
+        deployment_config (Dict): The deployment config to check.
+
+    Returns:
+        bool: Whether the deployment config contains a draft model or not.
+    """
+    if deployment_config is None:
+        return False
+    deployment_args = deployment_config.get("DeploymentArgs", {})
+    additional_data_sources = deployment_args.get("AdditionalDataSources")
+
+    return "speculative_decoding" in additional_data_sources if additional_data_sources else False
+
+
+def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool:
+    """Checks whether a deployment config's draft model is provided by JumpStart.
+
+    Args:
+        deployment_config (Dict): The deployment config to check.
+
+    Returns:
+        bool: Whether the draft model is provided by JumpStart or not.
+    """
+    if deployment_config is None:
+        return False
+
+    additional_model_data_sources = deployment_config.get("DeploymentArgs", {}).get(
+        "AdditionalDataSources"
+    )
+    if not additional_model_data_sources:
+        return False
+    for source in additional_model_data_sources.get("speculative_decoding", []):
+        if source["channel_name"] == "draft_model":
+            if source.get("provider", {}).get("name") == "JumpStart":
+                return True
+    return False
+
+
 def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model:
     """Generates a new optimization model.
@@ -164,12 +206,72 @@ def _extract_speculative_draft_model_provider(
     if speculative_decoding_config is None:
         return None
 
-    if speculative_decoding_config.get(
-        "ModelProvider"
-    ) == "Custom" or speculative_decoding_config.get("ModelSource"):
+    model_provider = speculative_decoding_config.get("ModelProvider", "").lower()
+
+    if model_provider == "jumpstart":
+        return "jumpstart"
+
+    if model_provider == "custom" or speculative_decoding_config.get("ModelSource"):
         return "custom"
 
-    return "sagemaker"
+    if model_provider == "sagemaker":
+        return "sagemaker"
+
+    return "auto"
+
+
+def _extract_additional_model_data_source_s3_uri(
+    additional_model_data_source: Optional[Dict] = None,
+) -> Optional[str]:
+    """Extracts the model data source S3 URI from a model data source in Pascal case.
+
+    Args:
+        additional_model_data_source (Optional[Dict]): A model data source.
+
+    Returns:
+        Optional[str]: The S3 URI of the model resources.
+    """
+    if (
+        additional_model_data_source is None
+        or additional_model_data_source.get("S3DataSource", None) is None
+    ):
+        return None
+
+    return additional_model_data_source.get("S3DataSource").get("S3Uri")
+
+
+def _extract_deployment_config_additional_model_data_source_s3_uri(
+    additional_model_data_source: Optional[Dict] = None,
+) -> Optional[str]:
+    """Extracts the model data source S3 URI from a model data source in snake case.
+
+    Args:
+        additional_model_data_source (Optional[Dict]): A model data source.
+
+    Returns:
+        Optional[str]: The S3 URI of the model resources.
+    """
+    if (
+        additional_model_data_source is None
+        or additional_model_data_source.get("s3_data_source", None) is None
+    ):
+        return None
+
+    return additional_model_data_source.get("s3_data_source").get("s3_uri", None)
+
+
+def _is_draft_model_gated(
+    draft_model_config: Optional[Dict] = None,
+) -> bool:
+    """Determines whether a draft model is gated based on its model data source.
+
+    Args:
+        draft_model_config (Optional[Dict]): A model data source.
+
+    Returns:
+        bool: Whether the draft model is gated or not.
+ """ + return "hosting_eula_key" in draft_model_config if draft_model_config else False def _extracts_and_validates_speculative_model_source( @@ -238,7 +340,7 @@ def _generate_additional_model_data_sources( }, } if accept_eula: - additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True} + additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True} return [additional_model_data_source] @@ -260,7 +362,7 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool: def _extract_optimization_config_and_env( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None -) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]: +) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: """Extracts optimization config and environment variables. Args: @@ -268,18 +370,28 @@ def _extract_optimization_config_and_env( compilation_config (Optional[Dict]): The compilation config. Returns: - Optional[Tuple[Optional[Dict], Optional[Dict]]]: + Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: The optimization config and environment variables. """ - if quantization_config: - return {"ModelQuantizationConfig": quantization_config}, quantization_config.get( - "OverrideEnvironment" - ) - if compilation_config: - return {"ModelCompilationConfig": compilation_config}, compilation_config.get( - "OverrideEnvironment" - ) - return None, None + optimization_config = {} + quantization_override_env = ( + quantization_config.get("OverrideEnvironment") if quantization_config else None + ) + compilation_override_env = ( + compilation_config.get("OverrideEnvironment") if compilation_config else None + ) + + if quantization_config is not None: + optimization_config["ModelQuantizationConfig"] = quantization_config + + if compilation_config is not None: + optimization_config["ModelCompilationConfig"] = compilation_config + + # Return optimization config dict and environment variables if either is present + if optimization_config: + return optimization_config, quantization_override_env, compilation_override_env + + return None, None, None def _custom_speculative_decoding( @@ -300,6 +412,8 @@ def _custom_speculative_decoding( speculative_decoding_config ) + accept_eula = speculative_decoding_config.get("AcceptEula", accept_eula) + if _is_s3_uri(additional_model_source): channel_name = _generate_channel_name(model.additional_model_data_sources) speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}" @@ -316,3 +430,62 @@ def _custom_speculative_decoding( ) return model + + +def _jumpstart_speculative_decoding( + model=Model, + speculative_decoding_config: Optional[Dict[str, Any]] = None, + sagemaker_session: Optional[Session] = None, +): + """Modifies the given model for speculative decoding config with JumpStart provider. + + Args: + model (Model): The model. + speculative_decoding_config (Optional[Dict]): The speculative decoding config. + sagemaker_session (Optional[Session]): Sagemaker session for execution. + """ + if speculative_decoding_config: + js_id = speculative_decoding_config.get("ModelID") + if not js_id: + raise ValueError( + "`ModelID` is a required field in `speculative_decoding_config` when " + "using JumpStart as draft model provider." 
+ ) + model_version = speculative_decoding_config.get("ModelVersion", "*") + accept_eula = speculative_decoding_config.get("AcceptEula", False) + channel_name = _generate_channel_name(model.additional_model_data_sources) + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + model_id=js_id, + version=model_version, + region=sagemaker_session.boto_region_name, + sagemaker_session=sagemaker_session, + ) + model_spec_json = model_specs.to_json() + + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket() + + if model_spec_json.get("gated_bucket", False): + if not accept_eula: + eula_message = get_eula_message( + model_specs=model_specs, region=sagemaker_session.boto_region_name + ) + raise ValueError( + f"{eula_message} Set `AcceptEula`=True in " + f"speculative_decoding_config once acknowledged." + ) + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket() + + key_prefix = model_spec_json.get("hosting_prepacked_artifact_key") + model.additional_model_data_sources = _generate_additional_model_data_sources( + f"s3://{js_bucket}/{key_prefix}", + channel_name, + accept_eula, + ) + + model.env.update( + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} + ) + model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"}, + ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 4c573bca8c..09e63f8f59 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -19,6 +19,7 @@ import pandas as pd from mock import MagicMock, Mock import pytest +from sagemaker_core.shapes import ModelAccessConfig from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.jumpstart.artifacts.environment_variables import ( _retrieve_default_environment_variables, @@ -54,6 +55,7 @@ get_base_deployment_configs, get_base_spec_with_prototype_configs_with_missing_benchmarks, append_instance_stat_metrics, + append_gated_draft_model_specs_to_jumpstart_model_spec, ) import boto3 @@ -772,6 +774,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): init_args_to_skip: Set[str] = set(["model_reference_arn"]) deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"]) + deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"]) parent_class_init = Model.__init__ parent_class_init_args = set(signature(parent_class_init).parameters.keys()) @@ -798,8 +801,14 @@ def test_jumpstart_model_kwargs_match_parent_class(self): js_class_deploy = JumpStartModel.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == set() - assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip + assert ( + js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time + == set() + ) + assert ( + parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time + == deploy_args_to_skip + ) @mock.patch( "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} @@ -1762,6 +1771,91 @@ def test_model_set_deployment_config( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + 
@mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_and_deploy_for_gated_draft_model( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + # WHERE + mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id = "pytorch-eqa-bert-base-cased" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + # WHEN + model.deploy( + model_access_configs={ + "pytorch-eqa-bert-base-cased": ModelAccessConfig(accept_eula=True) + } + ) + + # THEN + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + # WHERE + mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id = "pytorch-eqa-bert-base-cased" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + # WHEN / THEN + with self.assertRaises(ValueError): + model.deploy() + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1810,6 +1904,7 @@ def test_model_deployment_config_additional_model_data_source( "S3Uri": "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/", "ModelAccessConfig": {"AcceptEula": False}, }, + "HostingEulaKey": None, } ], ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fe2ba749cd..67681e2b7b 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -20,6 +20,7 @@ import pytest import boto3 import random +from sagemaker_core.shapes import ModelAccessConfig from sagemaker import 
session from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( @@ -2168,3 +2169,228 @@ def test_get_domain_for_region(self): " https://jumpstart-cache-prod-cn-north-1.s3.cn-north-1.amazonaws.com.cn/some-eula-key " "for terms of use.", ) + + +class TestAcceptEulaModelAccessConfig(TestCase): + MOCK_PUBLIC_MODEL_ID = "mock_public_model_id" + MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/public/resources/", + }, + "HostingEulaKey": None, + } + ] + MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/public/resources/", + }, + } + ] + MOCK_GATED_MODEL_ID = "mock_gated_model_id" + MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/", + }, + "HostingEulaKey": "fmhMetadata/eula/llama3_2Eula.txt", + } + ] + MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ] + + # Public Positive Cases + + def test_public_additional_model_data_source_should_pass_through(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert ( + additional_model_data_sources + == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_multiple_public_additional_model_data_source_should_pass_through_both(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == ( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_public_additional_model_data_source_with_model_access_config_should_ignore_it(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)}, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert ( + additional_model_data_sources + == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_no_additional_model_data_source_should_pass_through(self): + # WHERE / WHEN + additional_model_data_sources = 
utils._add_model_access_configs_to_model_data_sources(
+            model_data_sources=None,
+            model_access_configs=None,
+            model_id=self.MOCK_PUBLIC_MODEL_ID,
+            region=JUMPSTART_DEFAULT_REGION_NAME,
+        )
+
+        # THEN
+        assert not additional_model_data_sources
+
+    # Gated Positive Cases
+
+    def test_gated_additional_model_data_source_should_accept_it(self):
+        # WHERE / WHEN
+        additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
+            model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
+            model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
+            model_id=self.MOCK_GATED_MODEL_ID,
+            region=JUMPSTART_DEFAULT_REGION_NAME,
+        )
+
+        # THEN
+        assert (
+            additional_model_data_sources
+            == self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+        )
+
+    def test_multiple_gated_additional_model_data_source_should_accept_both(self):
+        # WHERE / WHEN
+        additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
+            model_data_sources=(
+                self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+                + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+            ),
+            model_access_configs={
+                self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True),
+            },
+            model_id=self.MOCK_GATED_MODEL_ID,
+            region=JUMPSTART_DEFAULT_REGION_NAME,
+        )
+
+        # THEN
+        assert additional_model_data_sources == (
+            self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+            + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+        )
+
+    # Mixed Positive Cases
+
+    def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(
+        self,
+    ):
+        # WHERE / WHEN
+        additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
+            model_data_sources=(
+                self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+                + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+            ),
+            model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
+            model_id=self.MOCK_GATED_MODEL_ID,
+            region=JUMPSTART_DEFAULT_REGION_NAME,
+        )
+
+        # THEN
+        assert additional_model_data_sources == (
+            self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+            + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
+        )
+
+    # Gated Negative Cases
+
+    def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error(
+        self,
+    ):
+        # WHERE / WHEN / THEN
+        with self.assertRaises(ValueError):
+            utils._add_model_access_configs_to_model_data_sources(
+                model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
+                model_access_configs=None,
+                model_id=self.MOCK_GATED_MODEL_ID,
+                region=JUMPSTART_DEFAULT_REGION_NAME,
+            )
+
+    def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error(self):
+        # WHERE / WHEN / THEN
+        with self.assertRaises(ValueError):
+            utils._add_model_access_configs_to_model_data_sources(
+                model_data_sources=(
+                    self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+                    + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
+                ),
+                model_access_configs=None,
+                model_id=self.MOCK_GATED_MODEL_ID,
+                region=JUMPSTART_DEFAULT_REGION_NAME,
+            )
+
+    def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error(
+        self,
+    ):
+        # WHERE / WHEN / THEN
+        with
self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_PUBLIC_MODEL_ID: ModelAccessConfig(accept_eula=True) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error( + self, + ): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index de274f0374..bd870dc461 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -32,6 +32,7 @@ DeploymentConfigMetadata, JumpStartModelDeployKwargs, JumpStartBenchmarkStat, + JumpStartAdditionalDataSources, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -436,3 +437,27 @@ def append_instance_stat_metrics( ) ) return metrics + + +def append_gated_draft_model_specs_to_jumpstart_model_spec(*args, **kwargs): + augmented_spec = get_prototype_model_spec(*args, **kwargs) + + gated_s3_uri = "meta-textgeneration/meta-textgeneration-llama-3-2-1b-instruct/artifacts/inference-prepack/v1.0.0/" + augmented_spec.hosting_additional_data_sources = JumpStartAdditionalDataSources( + spec={ + "speculative_decoding": [ + { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": gated_s3_uri, + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, + } + ] + } + ) + return augmented_spec diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 248955c273..25bc67d22d 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -25,7 +25,11 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) -from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS +from tests.unit.sagemaker.serve.constants import ( + DEPLOYMENT_CONFIGS, + OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, + CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES, +) mock_model_id = "huggingface-llm-amazon-falconlite" mock_t5_model_id = "google/flan-t5-xxl" @@ -1198,6 +1202,124 @@ def test_optimize_quantize_for_jumpstart( self.assertIsNotNone(out_put) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._jumpstart_speculative_decoding", + return_value=True, + ) + def test_jumpstart_model_provider_calls_jumpstart_speculative_decoding( + self, + mock_js_speculative_decoding, + 
mock_pretrained_js_model, + mock_is_js_model, + mock_serve_settings, + mock_capture_telemetry, + ): + mock_sagemaker_session = Mock() + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL + mock_pysdk_model.additional_model_data_sources = CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." + } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + model_builder._optimize_for_jumpstart( + accept_eula=True, + speculative_decoding_config={ + "ModelProvider": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": False, + }, + ) + + mock_js_speculative_decoding.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_quantize_and_compile_for_jumpstart( + self, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b", + } + + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pysdk_model.config_name = "config_name" + mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config} + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_jumpstart( + accept_eula=True, + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ) + + self.assertIsNotNone(out_put) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @patch( @@ -1383,3 +1505,103 @@ def test_optimize_compile_for_jumpstart_with_neuron_env( self.assertEqual(optimized_model.env["OPTION_ROLLING_BATCH"], "auto") self.assertEqual(optimized_model.env["OPTION_MAX_ROLLING_BATCH_SIZE"], "4") self.assertEqual(optimized_model.env["OPTION_NEURON_OPTIMIZE_LEVEL"], "2") + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_without_compilation_config( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b", + } + + mock_js_model.return_value = MagicMock() + mock_js_model.return_value.env = { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + } + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.config_name = "config_name" + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = { + "config_name": mock_metadata_config + } + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish 
coastal " + "tidal marshes of the east coast." + } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.24xlarge", + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + self.assertEqual(optimized_model.env["SAGEMAKER_PROGRAM"], "inference.py") + self.assertEqual(optimized_model.env["ENDPOINT_SERVER_TIMEOUT"], "3600") + self.assertEqual(optimized_model.env["MODEL_CACHE_ROOT"], "/opt/ml/model") + self.assertEqual(optimized_model.env["SAGEMAKER_ENV"], "1") + self.assertEqual(optimized_model.env["HF_MODEL_ID"], "/opt/ml/model") + self.assertEqual(optimized_model.env["SAGEMAKER_MODEL_SERVER_WORKERS"], "1") diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index b50aa17c34..2da09aece3 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -2650,21 +2650,75 @@ def test_optimize_local_mode(self, mock_get_serve_setting): ), ) + @patch.object(ModelBuilder, "_prepare_for_mode") @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) - def test_optimize_exclusive_args(self, mock_get_serve_setting): - mock_sagemaker_session = Mock() + def test_optimize_for_hf_with_both_quantization_and_compilation( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruc"} + model_builder = ModelBuilder( - model="meta-textgeneration-llama-3-70b", - sagemaker_session=mock_sagemaker_session, + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", ) - self.assertRaisesRegex( - ValueError, - "Quantization config and compilation config are mutually exclusive.", - lambda: model_builder.optimize( - quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, - compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, - ), + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_hf( + job_name="job_name-123", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ) + + self.assertEqual(model_builder.env_vars["HF_TOKEN"], "token") + self.assertEqual(model_builder.role_arn, "role-arn") + self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge") + self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq") + self.assertEqual(model_builder.pysdk_model.env["OPTION_TENSOR_PARALLEL_DEGREE"], "2") + self.assertEqual( + out_put, + { + 
"OptimizationJobName": "job_name-123", + "DeploymentInstanceType": "ml.g5.2xlarge", + "RoleArn": "role-arn", + "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}}, + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"} + } + }, + { + "ModelCompilationConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"} + } + }, + ], + "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"}, + }, ) @patch.object(ModelBuilder, "_prepare_for_mode") @@ -2786,3 +2840,109 @@ def test_optimize_for_hf_without_custom_s3_path( "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"}, }, ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"} + + sample_input = {"inputs": "dummy prompt", "parameters": {}} + + sample_output = [{"generated_text": "dummy response"}] + + dummy_schema_builder = SchemaBuilder(sample_input, sample_output) + + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-1-8B-Instruct", + schema_builder=dummy_schema_builder, + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + + model_builder.pysdk_model = mock_pysdk_model + + self.assertRaisesRegex( + ValueError, + "Compilation is not supported for Llama-3.1 with a GPU instance.", + lambda: model_builder.optimize( + job_name="job_name-123", + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ), + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "modelid"} + + sample_input = {"inputs": "dummy prompt", "parameters": {}} + + sample_output = [{"generated_text": "dummy response"}] + + dummy_schema_builder = SchemaBuilder(sample_input, sample_output) + + model_builder = ModelBuilder( + model="modelid", + schema_builder=dummy_schema_builder, + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + + model_builder.pysdk_model = mock_pysdk_model + + self.assertRaisesRegex( + ValueError, + "Compilation is not supported with speculative decoding with a GPU instance.", + lambda: model_builder.optimize( + job_name="job_name-123", + speculative_decoding_config={ + "ModelProvider": "custom", + "ModelSource": "s3://data-source", + }, + compilation_config={"OverrideEnvironment": 
{"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ), + ) diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py index 5a4679747b..3e776eaa46 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -165,3 +165,153 @@ }, }, ] +NON_OPTIMIZED_DEPLOYMENT_CONFIG = { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, +} +OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL = { + "DeploymentConfigName": "lmi-optimized", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "djl-inference:0.29.0-lmi11.0.0-cu124", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-1-70b/artifacts/inference-prepack/v2.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "ModelPackageArn": None, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.g6.2xlarge", + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 131072, + "NumberOfAcceleratorDevicesRequired": 1, + }, + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "AdditionalDataSources": { + "speculative_decoding": [ + { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, + } + ] + }, + }, + "AccelerationConfigs": [ + { + "type": "Compilation", + "enabled": False, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1", + } + }, + }, + { + "type": "Speculative-Decoding", + "enabled": True, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "LMI v11 does not support Speculative Decoding for TRT", + } + }, + }, + { + "type": "Quantization", + "enabled": 
False, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1", + } + }, + }, + ], + "BenchmarkMetrics": {"ml.g6.2xlarge": None}, +} +GATED_DRAFT_MODEL_CONFIG = { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, +} +NON_GATED_DRAFT_MODEL_CONFIG = { + "channel_name": "draft_model", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://sagemaker-sd-models-beta-us-west-2/" + "sagemaker-speculative-decoding-llama3-small-v3/", + }, +} +CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "CompressionType": "None", + "S3DataType": "S3Prefix", + }, + } +] diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index a8dc6d74f4..7cf0406f42 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -31,6 +31,15 @@ _is_optimized, _custom_speculative_decoding, _is_inferentia_or_trainium, + _is_draft_model_gated, + _deployment_config_contains_draft_model, + _jumpstart_speculative_decoding, +) +from tests.unit.sagemaker.serve.constants import ( + GATED_DRAFT_MODEL_CONFIG, + NON_GATED_DRAFT_MODEL_CONFIG, + OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, + NON_OPTIMIZED_DEPLOYMENT_CONFIG, ) mock_optimization_job_output = { @@ -180,6 +189,9 @@ def test_update_environment_variables(env, new_env, output_env): ({"ModelProvider": "SageMaker"}, "sagemaker"), ({"ModelProvider": "Custom"}, "custom"), ({"ModelSource": "s3://"}, "custom"), + ({"ModelProvider": "JumpStart"}, "jumpstart"), + ({"ModelProvider": "asdf"}, "auto"), + ({"ModelProvider": "Auto"}, "auto"), (None, None), ], ) @@ -224,7 +236,7 @@ def test_generate_additional_model_data_sources(): "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "S3DataType": "S3Prefix", "CompressionType": "None", - "ModelAccessConfig": {"ACCEPT_EULA": True}, + "ModelAccessConfig": {"AcceptEula": True}, }, } ] @@ -261,7 +273,18 @@ def test_is_s3_uri(s3_uri, expected): @pytest.mark.parametrize( - "quantization_config, compilation_config, expected_config, expected_env", + "draft_model_config, expected", + [ + (GATED_DRAFT_MODEL_CONFIG, True), + (NON_GATED_DRAFT_MODEL_CONFIG, False), + ], +) +def test_is_draft_model_gated(draft_model_config, expected): + assert _is_draft_model_gated(draft_model_config) is expected + + +@pytest.mark.parametrize( + "quantization_config, compilation_config, expected_config, expected_quant_env, expected_compilation_env", [ ( None, @@ -277,6 +300,7 @@ def test_is_s3_uri(s3_uri, expected): } }, }, + None, { "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, @@ -298,17 +322,162 @@ def test_is_s3_uri(s3_uri, expected): { "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, + None, ), - (None, None, None, None), + (None, None, None, None, None), ], ) def test_extract_optimization_config_and_env( - quantization_config, compilation_config, expected_config, expected_env + quantization_config, + 
compilation_config,
+    expected_config,
+    expected_quant_env,
+    expected_compilation_env,
 ):
     assert _extract_optimization_config_and_env(quantization_config, compilation_config) == (
         expected_config,
-        expected_env,
+        expected_quant_env,
+        expected_compilation_env,
+    )
+
+
+@pytest.mark.parametrize(
+    "deployment_config, expected",
+    [
+        (OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, True),
+        (NON_OPTIMIZED_DEPLOYMENT_CONFIG, False),
+        (None, False),
+    ],
+)
+def test_deployment_config_contains_draft_model(deployment_config, expected):
+    assert _deployment_config_contains_draft_model(deployment_config) is expected
+
+
+class TestJumpStartSpeculativeDecodingConfig(unittest.TestCase):
+
+    @patch("sagemaker.model.Model")
+    def test_with_no_js_model_id(self, mock_model):
+        mock_model.env = {}
+        mock_model.additional_model_data_sources = None
+        speculative_decoding_config = {"ModelSource": "JumpStart"}
+
+        with self.assertRaises(ValueError) as _:
+            _jumpstart_speculative_decoding(mock_model, speculative_decoding_config)
+
+    @patch(
+        "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket",
+        return_value="js_gated_content_bucket",
+    )
+    @patch(
+        "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket",
+        return_value="js_content_bucket",
+    )
+    @patch(
+        "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs",
+        return_value=Mock(),
+    )
+    @patch("sagemaker.model.Model")
+    def test_with_gated_js_model(
+        self,
+        mock_model,
+        mock_model_specs,
+        mock_js_content_bucket,
+        mock_js_gated_content_bucket,
+    ):
+        mock_sagemaker_session = Mock()
+        mock_sagemaker_session.boto_region_name = "us-west-2"
+
+        mock_model.env = {}
+        mock_model.additional_model_data_sources = None
+        speculative_decoding_config = {
+            "ModelSource": "JumpStart",
+            "ModelID": "meta-textgeneration-llama-3-2-1b",
+            "AcceptEula": True,
+        }
+
+        mock_model_specs.return_value.to_json.return_value = {
+            "gated_bucket": True,
+            "hosting_prepacked_artifact_key": "hosting_prepacked_artifact_key",
+        }
+
+        _jumpstart_speculative_decoding(
+            mock_model, speculative_decoding_config, mock_sagemaker_session
+        )
+
+        expected_env_var = {
+            "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/"
+        }
+        self.maxDiff = None
+
+        self.assertEqual(
+            mock_model.additional_model_data_sources,
+            [
+                {
+                    "ChannelName": "draft_model",
+                    "S3DataSource": {
+                        "S3Uri": f"s3://{mock_js_gated_content_bucket.return_value}/hosting_prepacked_artifact_key",
+                        "S3DataType": "S3Prefix",
+                        "CompressionType": "None",
+                        "ModelAccessConfig": {"AcceptEula": True},
+                    },
+                }
+            ],
+        )
+
+        mock_model.add_tags.assert_called_once_with(
+            {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"}
+        )
+        self.assertEqual(mock_model.env, expected_env_var)
+
+    @patch(
+        "sagemaker.serve.utils.optimize_utils.get_eula_message", return_value="Accept eula message"
+    )
+    @patch(
+        "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket",
+        return_value="js_gated_content_bucket",
+    )
+    @patch(
+        "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket",
+        return_value="js_content_bucket",
+    )
+    @patch(
+        "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs",
+        return_value=Mock(),
+    )
+    @patch("sagemaker.model.Model")
+    def test_with_gated_js_model_and_accept_eula_false(
+        self,
+        mock_model,
+        mock_model_specs,
+        mock_js_content_bucket,
+        mock_js_gated_content_bucket,
+        
mock_eula_message, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.boto_region_name = "us-west-2" + + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": False, + } + + mock_model_specs.return_value.to_json.return_value = { + "gated_bucket": True, + "hosting_prepacked_artifact_key": "hosting_prepacked_artifact_key", + } + + self.assertRaisesRegex( + ValueError, + f"{mock_eula_message.return_value} Set `AcceptEula`=True in " + f"speculative_decoding_config once acknowledged.", + _jumpstart_speculative_decoding, + mock_model, + speculative_decoding_config, + mock_sagemaker_session, + ) class TestCustomSpeculativeDecodingConfig(unittest.TestCase): @@ -364,7 +533,7 @@ def test_with_s3_js(self, mock_model): "S3Uri": "s3://bucket/huggingface-pytorch-tgi-inference", "S3DataType": "S3Prefix", "CompressionType": "None", - "ModelAccessConfig": {"ACCEPT_EULA": True}, + "ModelAccessConfig": {"AcceptEula": True}, }, } ],