Skip to content

Commit

Permalink
Merge branch 'master' into test-tf-2-16-inf
Browse files Browse the repository at this point in the history
  • Loading branch information
shantanutrip authored Aug 22, 2024
2 parents 8493136 + e240518 commit 3e55e7d
Show file tree
Hide file tree
Showing 28 changed files with 516 additions and 469 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ max-returns=6
max-branches=12

# Maximum number of statements in function / method body
max-statements=100
max-statements=105

# Maximum number of parents for a class (see R0901).
max-parents=7
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,20 @@ def __init__(
available (default: ``None``).
**kwargs: Additional kwargs. This is unused. It's only added for AlgorithmEstimator
to ignore the irrelevant arguments.
Raises:
ValueError:
- If an AWS IAM Role is not provided.
- Bad value for instance type.
RuntimeError:
- When setting up custom VPC, both subnets and security_group_ids are not provided
- If instance_count > 1 (distributed training) with instance type local or local gpu
- If LocalSession is not used with instance type local or local gpu
- file:// output path used outside of local mode
botocore.exceptions.ClientError:
- algorithm arn is incorrect
- insufficient permission to access/ describe algorithm
- algorithm is in a different region
"""
self.algorithm_arn = algorithm_arn
super(AlgorithmEstimator, self).__init__(
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ def update_endpoint(
- If ``initial_instance_count``, ``instance_type``, or ``accelerator_type`` is
specified and either ``model_name`` is ``None`` or there are multiple models
associated with the endpoint.
botocore.exceptions.ClientError: If SageMaker throws an error while creating
endpoint config, describing endpoint or updating endpoint
"""
production_variants = None
current_model_names = self._get_model_names()
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.enums import JumpStartModelType, JumpStartScriptScope
from sagemaker.session import Session

logger = logging.getLogger(__name__)
Expand All @@ -38,6 +38,7 @@ def retrieve_default(
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.
Expand Down Expand Up @@ -70,6 +71,8 @@ def retrieve_default(
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: The variables to use for the model.
Expand All @@ -94,4 +97,5 @@ def retrieve_default(
instance_type=instance_type,
script=script,
config_name=config_name,
model_type=model_type,
)
32 changes: 27 additions & 5 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,25 +590,36 @@ def __init__(
self.dependencies = dependencies or []
self.uploaded_code: Optional[UploadedCode] = None

# Check that the user properly sets both subnet and secutiry_groupe_ids
# Check that the user properly sets both subnet and security_group_ids
if (
subnets is not None
and security_group_ids is None
or security_group_ids is not None
and subnets is None
):
troubleshooting = (
"Refer to this documentation on using custom VPC: "
"https://sagemaker.readthedocs.io/en/v2.24.0/overview.html"
"#secure-training-and-inference-with-vpc"
)
logger.error("Check troubleshooting guide for common errors: %s", troubleshooting)

raise RuntimeError(
"When setting up custom VPC, both subnets and security_group_ids must be set"
)

if self.instance_type in ("local", "local_gpu"):
if self.instance_type == "local_gpu" and self.instance_count > 1:
raise RuntimeError("Distributed Training in Local GPU is not supported")
raise RuntimeError(
"Distributed Training in Local GPU is not supported."
" Set instance_count to 1."
)
self.sagemaker_session = sagemaker_session or LocalSession()
if not isinstance(self.sagemaker_session, sagemaker.local.LocalSession):
raise RuntimeError(
"instance_type local or local_gpu is only supported with an"
"instance of LocalSession"
"instance of LocalSession. More details on local mode: "
"https://sagemaker.readthedocs.io/en/stable/overview.html#local-mode"
)
else:
self.sagemaker_session = sagemaker_session or Session()
Expand All @@ -631,7 +642,11 @@ def __init__(
and not is_pipeline_variable(output_path)
and output_path.startswith("file://")
):
raise RuntimeError("file:// output paths are only supported in Local Mode")
raise RuntimeError(
"The 'file://' output paths are only supported when using Local Mode. "
"To resolve this issue, ensure you're running in Local Mode with a LocalSession, "
"or use an 's3://' output path for jobs running on SageMaker instances."
)
self.output_path = output_path
self.latest_training_job = None
self.jobs = []
Expand All @@ -646,7 +661,12 @@ def __init__(
# Now we marked that as Optional because we can fetch it from SageMakerConfig
# Because of marking that parameter as optional, we should validate if it is None, even
# after fetching the config.
raise ValueError("An AWS IAM role is required to create an estimator.")
raise ValueError(
"An AWS IAM role is required to create an estimator. "
"Please provide a valid `role` argument with the ARN of an IAM role"
" that has the necessary SageMaker permissions."
)

self.output_kms_key = resolve_value_from_config(
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
)
Expand Down Expand Up @@ -1855,6 +1875,8 @@ def model_data(self):
if compression_type not in {"GZIP", "NONE"}:
raise ValueError(
f'Unrecognized training job output data compression type "{compression_type}"'
'. Please specify either "GZIP" or "NONE" as valid options for '
"the compression type."
)
# model data is in uncompressed form NOTE SageMaker Hosting mandates presence of
# trailing forward slash in S3 model data URI, so append one if necessary.
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import HyperparameterValidationMode
from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType
from sagemaker.jumpstart.validators import validate_hyperparameters
from sagemaker.session import Session

Expand All @@ -38,6 +38,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
"""Retrieves the default training hyperparameters for the model matching the given arguments.
Expand Down Expand Up @@ -71,6 +72,8 @@ def retrieve_default(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: The hyperparameters to use for the model.
Expand All @@ -93,6 +96,7 @@ def retrieve_default(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)


Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from sagemaker import utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
from sagemaker.jumpstart import artifacts
Expand Down Expand Up @@ -72,6 +73,7 @@ def retrieve(
serverless_inference_config=None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name=None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.
Expand Down Expand Up @@ -128,6 +130,8 @@ def retrieve(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
Expand Down Expand Up @@ -169,6 +173,7 @@ def retrieve(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
Expand Down
11 changes: 10 additions & 1 deletion src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
Expand All @@ -41,6 +42,7 @@ def _retrieve_default_environment_variables(
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
"""Retrieves the inference environment variables for the model matching the given arguments.
Expand Down Expand Up @@ -73,6 +75,8 @@ def _retrieve_default_environment_variables(
script (JumpStartScriptScope): The JumpStart script for which to retrieve
environment variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: the inference environment variables to use for the model.
"""
Expand All @@ -91,6 +95,7 @@ def _retrieve_default_environment_variables(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

default_environment_variables: Dict[str, str] = {}
Expand Down Expand Up @@ -130,6 +135,7 @@ def _retrieve_default_environment_variables(
sagemaker_session=sagemaker_session,
instance_type=instance_type,
config_name=config_name,
model_type=model_type,
)
)

Expand Down Expand Up @@ -178,6 +184,7 @@ def _retrieve_gated_model_uri_env_var_value(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
"""Retrieves the gated model env var URI matching the given arguments.
Expand All @@ -204,7 +211,8 @@ def _retrieve_gated_model_uri_env_var_value(
instance_type (str): An instance type to optionally supply in order to get
environment variables specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
have gated training artifacts.
Expand All @@ -227,6 +235,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

s3_key: Optional[str] = (
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
VariableScope,
)
Expand All @@ -38,6 +39,7 @@ def _retrieve_default_hyperparameters(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
):
"""Retrieves the training hyperparameters for the model matching the given arguments.
Expand Down Expand Up @@ -71,6 +73,8 @@ def _retrieve_default_hyperparameters(
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: the hyperparameters to use for the model.
"""
Expand All @@ -89,6 +93,7 @@ def _retrieve_default_hyperparameters(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

default_hyperparameters: Dict[str, str] = {}
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
ModelFramework,
)
Expand Down Expand Up @@ -48,6 +49,7 @@ def _retrieve_image_uri(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
):
"""Retrieves the container image URI for JumpStart models.
Expand Down Expand Up @@ -100,6 +102,8 @@ def _retrieve_image_uri(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
str: the ECR URI for the corresponding SageMaker Docker image.
Expand All @@ -123,6 +127,7 @@ def _retrieve_image_uri(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

if image_scope == JumpStartScriptScope.INFERENCE:
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
Expand All @@ -35,6 +36,7 @@ def _model_supports_incremental_training(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> bool:
"""Returns True if the model supports incremental training.
Expand All @@ -59,6 +61,8 @@ def _model_supports_incremental_training(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
bool: the support status for incremental training.
"""
Expand All @@ -77,6 +81,7 @@ def _model_supports_incremental_training(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

return model_specs.supports_incremental_training()
Loading

0 comments on commit 3e55e7d

Please sign in to comment.