Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Enabled update_endpoint through model_builder #5085

Merged
merged 22 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
25f16ef
change: Allow telemetry only in supported regions
Jan 29, 2025
0ed85d6
change: Allow telemetry only in supported regions
Jan 29, 2025
b69ffcb
change: Allow telemetry only in supported regions
Jan 29, 2025
8d7f4a8
change: Allow telemetry only in supported regions
Jan 29, 2025
9321367
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Jan 29, 2025
f972222
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Jan 30, 2025
dadbb22
change: Allow telemetry only in supported regions
Jan 30, 2025
28b3fe8
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Feb 23, 2025
fe64f82
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Feb 24, 2025
7775c63
documentation: Removed a line about python version requirements of tr…
Feb 24, 2025
acc861a
Merge branch 'master' into rsareddy-dev
rsareddy0329 Feb 24, 2025
16dc02b
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 10, 2025
06597c6
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 11, 2025
249872d
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 12, 2025
58f8746
feature: Enabled update_endpoint through model_builder
Mar 12, 2025
c6bad70
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 12, 2025
0bf6404
fix: fix unit test, black-check, pylint errors
Mar 12, 2025
c67d7df
fix: fix black-check, pylint errors
Mar 12, 2025
1f84662
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 13, 2025
ea1810b
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 14, 2025
6079269
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 14, 2025
c9fcefb
Merge branch 'aws:master' into rsareddy-dev
rsareddy0329 Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def deploy(
container_startup_health_check_timeout=None,
inference_recommendation_id=None,
explainer_config=None,
update_endpoint: Optional[bool] = False,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
Expand Down Expand Up @@ -296,6 +297,11 @@ def deploy(
would like to deploy the model and endpoint with recommended parameters.
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
configuration for use with Amazon SageMaker Clarify. (default: None)
update_endpoint (Optional[bool]):
Flag to update the model in an existing Amazon SageMaker endpoint.
If True, this will deploy a new EndpointConfig to an already existing endpoint
and delete resources corresponding to the previous EndpointConfig. Default: False
Note: Currently this is supported for single model endpoints
Raises:
ValueError: If arguments combination check failed in these circumstances:
- If no role is specified or
Expand Down Expand Up @@ -335,6 +341,7 @@ def deploy(
container_startup_health_check_timeout=container_startup_health_check_timeout,
inference_recommendation_id=inference_recommendation_id,
explainer_config=explainer_config,
update_endpoint=update_endpoint,
**kwargs,
)

Expand Down
56 changes: 42 additions & 14 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.session import Session
from sagemaker.model_metrics import ModelMetrics
from sagemaker.deprecations import removed_kwargs
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.explainer import ExplainerConfig
from sagemaker.metadata_properties import MetadataProperties
Expand Down Expand Up @@ -1386,6 +1385,7 @@ def deploy(
routing_config: Optional[Dict[str, Any]] = None,
model_reference_arn: Optional[str] = None,
inference_ami_version: Optional[str] = None,
update_endpoint: Optional[bool] = False,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
Expand Down Expand Up @@ -1497,6 +1497,11 @@ def deploy(
inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured
Amazon Machine Image (AMI) images. For a full list of options, see:
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html
update_endpoint (Optional[bool]):
Flag to update the model in an existing Amazon SageMaker endpoint.
If True, this will deploy a new EndpointConfig to an already existing endpoint
and delete resources corresponding to the previous EndpointConfig. Default: False
Note: Currently this is supported for single model endpoints
Raises:
ValueError: If arguments combination check failed in these circumstances:
- If no role is specified or
Expand All @@ -1512,8 +1517,6 @@ def deploy(
"""
self.accept_eula = accept_eula

removed_kwargs("update_endpoint", kwargs)

self._init_sagemaker_session_if_does_not_exist(instance_type)
# Depending on the instance type, a local session (or) a session is initialized.
self.role = resolve_value_from_config(
Expand Down Expand Up @@ -1628,6 +1631,10 @@ def deploy(

# Support multiple models on same endpoint
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
if update_endpoint:
raise ValueError(
"Currently update_endpoint is supported for single model endpoints"
)
if endpoint_name:
self.endpoint_name = endpoint_name
else:
Expand Down Expand Up @@ -1783,17 +1790,38 @@ def deploy(
if is_explainer_enabled:
explainer_config_dict = explainer_config._to_request_dict()

self.sagemaker_session.endpoint_from_production_variants(
name=self.endpoint_name,
production_variants=[production_variant],
tags=tags,
kms_key=kms_key,
wait=wait,
data_capture_config_dict=data_capture_config_dict,
explainer_config_dict=explainer_config_dict,
async_inference_config_dict=async_inference_config_dict,
live_logging=endpoint_logging,
)
if update_endpoint:
endpoint_config_name = self.sagemaker_session.create_endpoint_config(
name=self.name,
model_name=self.name,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
accelerator_type=accelerator_type,
tags=tags,
kms_key=kms_key,
data_capture_config_dict=data_capture_config_dict,
volume_size=volume_size,
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
explainer_config_dict=explainer_config_dict,
async_inference_config_dict=async_inference_config_dict,
serverless_inference_config=serverless_inference_config_dict,
routing_config=routing_config,
inference_ami_version=inference_ami_version,
)
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
else:
self.sagemaker_session.endpoint_from_production_variants(
name=self.endpoint_name,
production_variants=[production_variant],
tags=tags,
kms_key=kms_key,
wait=wait,
data_capture_config_dict=data_capture_config_dict,
explainer_config_dict=explainer_config_dict,
async_inference_config_dict=async_inference_config_dict,
live_logging=endpoint_logging,
)

if self.predictor_cls:
predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session)
Expand Down
18 changes: 17 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,6 +1602,7 @@ def deploy(
ResourceRequirements,
]
] = None,
update_endpoint: Optional[bool] = False,
) -> Union[Predictor, Transformer]:
"""Deploys the built Model.

Expand All @@ -1615,24 +1616,33 @@ def deploy(
AsyncInferenceConfig, BatchTransformInferenceConfig, ResourceRequirements]]) :
Additional Config for different deployment types such as
serverless, async, batch and multi-model/container
update_endpoint (Optional[bool]):
Flag to update the model in an existing Amazon SageMaker endpoint.
If True, this will deploy a new EndpointConfig to an already existing endpoint
and delete resources corresponding to the previous EndpointConfig. Default: False
Note: Currently this is supported for single model endpoints
Returns:
Transformer for Batch Deployments
Predictors for all others
"""
if not hasattr(self, "built_model"):
raise ValueError("Model Needs to be built before deploying")
endpoint_name = unique_name_from_base(endpoint_name)
if not update_endpoint:
endpoint_name = unique_name_from_base(endpoint_name)

if not inference_config: # Real-time Deployment
return self.built_model.deploy(
instance_type=self.instance_type,
initial_instance_count=initial_instance_count,
endpoint_name=endpoint_name,
update_endpoint=update_endpoint,
)

if isinstance(inference_config, ServerlessInferenceConfig):
return self.built_model.deploy(
serverless_inference_config=inference_config,
endpoint_name=endpoint_name,
update_endpoint=update_endpoint,
)

if isinstance(inference_config, AsyncInferenceConfig):
Expand All @@ -1641,6 +1651,7 @@ def deploy(
initial_instance_count=initial_instance_count,
async_inference_config=inference_config,
endpoint_name=endpoint_name,
update_endpoint=update_endpoint,
)

if isinstance(inference_config, BatchTransformInferenceConfig):
Expand All @@ -1652,6 +1663,10 @@ def deploy(
return transformer

if isinstance(inference_config, ResourceRequirements):
if update_endpoint:
raise ValueError(
"Currently update_endpoint is supported for single model endpoints"
)
# Multi Model and MultiContainer endpoints with Inference Component
return self.built_model.deploy(
instance_type=self.instance_type,
Expand All @@ -1660,6 +1675,7 @@ def deploy(
resources=inference_config,
initial_instance_count=initial_instance_count,
role=self.role_arn,
update_endpoint=update_endpoint,
)

raise ValueError("Deployment Options not supported")
Expand Down
39 changes: 39 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4488,6 +4488,10 @@ def create_endpoint_config(
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
explainer_config_dict=None,
async_inference_config_dict=None,
serverless_inference_config_dict=None,
routing_config: Optional[Dict[str, Any]] = None,
inference_ami_version: Optional[str] = None,
):
"""Create an Amazon SageMaker endpoint configuration.

Expand Down Expand Up @@ -4525,6 +4529,30 @@ def create_endpoint_config(
-inference-algo-ping-requests
explainer_config_dict (dict): Specifies configuration to enable explainers.
Default: None.
async_inference_config_dict (dict): Specifies
configuration related to async endpoint. Use this configuration when trying
to create async endpoint and make async inference. If empty config object
passed through, will use default config to deploy async endpoint. Deploy a
real-time endpoint if it's None. (default: None).
serverless_inference_config_dict (dict):
Specifies configuration related to serverless endpoint. Use this configuration
when trying to create serverless endpoint and make serverless inference. If
empty object passed through, will use pre-defined values in
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
instance based endpoint if it's None. (default: None).
routing_config (Optional[Dict[str, Any]]): Settings that control how the endpoint routes
incoming traffic to the instances that the endpoint hosts.
Currently, supports the dictionary key ``RoutingStrategy``.

.. code:: python

{
"RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
}
inference_ami_version (Optional[str]):
Specifies an option from a collection of preconfigured
Amazon Machine Image (AMI) images. For a full list of options, see:
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html

Example:
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
Expand All @@ -4544,9 +4572,12 @@ def create_endpoint_config(
instance_type,
initial_instance_count,
accelerator_type=accelerator_type,
serverless_inference_config=serverless_inference_config_dict,
volume_size=volume_size,
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
routing_config=routing_config,
inference_ami_version=inference_ami_version,
)
production_variants = [provided_production_variant]
# Currently we just inject CoreDumpConfig.KmsKeyId from the config for production variant.
Expand Down Expand Up @@ -4586,6 +4617,14 @@ def create_endpoint_config(
)
request["DataCaptureConfig"] = inferred_data_capture_config_dict

if async_inference_config_dict is not None:
inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config(
async_inference_config_dict,
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
sagemaker_session=self,
)
request["AsyncInferenceConfig"] = inferred_async_inference_config_dict

if explainer_config_dict is not None:
request["ExplainerConfig"] = explainer_config_dict

Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def deploy(
container_startup_health_check_timeout=None,
inference_recommendation_id=None,
explainer_config=None,
update_endpoint: Optional[bool] = False,
**kwargs,
):
"""Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""
Expand All @@ -383,6 +384,7 @@ def deploy(
container_startup_health_check_timeout=container_startup_health_check_timeout,
inference_recommendation_id=inference_recommendation_id,
explainer_config=explainer_config,
update_endpoint=update_endpoint,
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
and reach out to JumpStart team."""

init_args_to_skip: Set[str] = set(["model_reference_arn"])
deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"])
deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn", "update_endpoint"])
deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"])

parent_class_init = Model.__init__
Expand Down
Loading
Loading