From eb2cc11a7a07434881ed8ebe792f6e0104939652 Mon Sep 17 00:00:00 2001 From: Ryan Steakley Date: Wed, 2 Oct 2024 20:00:46 +0000 Subject: [PATCH] Fix pylint + doc check --- .../mlflow/forward_sagemaker_metrics.py | 64 ++++++++----------- .../mlflow/test_forward_sagemaker_metrics.py | 8 +-- 2 files changed, 31 insertions(+), 41 deletions(-) diff --git a/src/sagemaker/mlflow/forward_sagemaker_metrics.py b/src/sagemaker/mlflow/forward_sagemaker_metrics.py index 05daaca588..a025241d6a 100644 --- a/src/sagemaker/mlflow/forward_sagemaker_metrics.py +++ b/src/sagemaker/mlflow/forward_sagemaker_metrics.py @@ -11,23 +11,24 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow.""" + +from __future__ import absolute_import -import boto3 import os +import platform +import re +from typing import Set, Tuple, List, Dict, Generator +import boto3 import mlflow from mlflow import MlflowClient from mlflow.entities import Metric, Param, RunTag from packaging import version -import platform -import re - -from typing import Set, Tuple, List, Dict def encode(name: str, existing_names: Set[str]) -> str: - """ - Encode a string to comply with MLflow naming restrictions and ensure uniqueness. + """Encode a string to comply with MLflow naming restrictions and ensure uniqueness. Args: name (str): The original string to be encoded. @@ -55,7 +56,8 @@ def encode_char(match): if base_name in existing_names: suffix = 1 - # Edge case where even with suffix space there is a collision we will override one of the keys. + # Edge case where even with suffix space there is a collision + # we will override one of the keys. while f"{base_name}_{suffix}" in existing_names: suffix += 1 encoded = f"{base_name}_{suffix}" @@ -68,9 +70,7 @@ def encode_char(match): def decode(encoded_metric_name: str) -> str: - - # TODO: Utilize the stored name mappings to get the original key mappings without having to decode. - """Decodes an encoded metric name by replacing hexadecimal representations with their corresponding characters. + """Decodes an encoded metric name by replacing hexadecimal representations with ASCII This function reverses the encoding process by converting hexadecimal codes back to their original characters. It looks for patterns of the form "_XX_" @@ -100,8 +100,7 @@ def replace_code(match): def get_training_job_details(job_arn: str) -> dict: - """ - Retrieve details of a SageMaker training job. + """Retrieve details of a SageMaker training job. Args: job_arn (str): The ARN of the SageMaker training job. @@ -118,8 +117,7 @@ def get_training_job_details(job_arn: str) -> dict: def create_metric_queries(job_arn: str, metric_definitions: list) -> list: - """ - Create metric queries for SageMaker metrics. + """Create metric queries for SageMaker metrics. Args: job_arn (str): The ARN of the SageMaker training job. @@ -142,8 +140,7 @@ def create_metric_queries(job_arn: str, metric_definitions: list) -> list: def get_metric_data(metric_queries: list) -> dict: - """ - Retrieve metric data from SageMaker. + """Retrieve metric data from SageMaker. Args: metric_queries (list): A list of metric queries. @@ -162,8 +159,7 @@ def get_metric_data(metric_queries: list) -> dict: def prepare_mlflow_metrics( metric_queries: list, metric_results: list ) -> Tuple[List[Metric], Dict[str, str]]: - """ - Prepare metrics for MLflow logging, encoding metric names if necessary. + """Prepare metrics for MLflow logging, encoding metric names if necessary. Args: metric_queries (list): The original metric queries sent to SageMaker. @@ -184,21 +180,17 @@ def prepare_mlflow_metrics( encoded_name = encode(metric_name, existing_names) metric_name_mapping[encoded_name] = metric_name - mlflow_metrics.extend( - [ - Metric(key=encoded_name, value=value, timestamp=timestamp, step=step) - for step, (timestamp, value) in enumerate( - zip(result["XAxisValues"], result["MetricValues"]) - ) - ] - ) + for step, (timestamp, value) in enumerate( + zip(result["XAxisValues"], result["MetricValues"]) + ): + metric = Metric(key=encoded_name, value=value, timestamp=timestamp, step=step) + mlflow_metrics.append(metric) return mlflow_metrics, metric_name_mapping def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]: - """ - Prepare hyperparameters for MLflow logging, encoding parameter names if necessary. + """Prepare hyperparameters for MLflow logging, encoding parameter names if necessary. Args: hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job. @@ -206,7 +198,8 @@ def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Returns: Tuple[List[Param], Dict[str, str]]: - A list of Param objects with encoded names (if necessary) - - A mapping of encoded to original names for hyperparameters (only for encoded parameters) + - A mapping of encoded to original names for + hyperparameters (only for encoded parameters) """ mlflow_params = [] param_name_mapping = {} @@ -220,9 +213,8 @@ def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], return mlflow_params, param_name_mapping -def batch_items(items: list, batch_size: int) -> list: - """ - Yield successive batch_size chunks from items. +def batch_items(items: list, batch_size: int) -> Generator: + """Yield successive batch_size chunks from items. Args: items (list): The list of items to be batched. @@ -236,8 +228,7 @@ def batch_items(items: list, batch_size: int) -> list: def log_to_mlflow(metrics: list, params: list, tags: dict) -> None: - """ - Log metrics, parameters, and tags to MLflow. + """Log metrics, parameters, and tags to MLflow. Args: metrics (list): List of metrics to log. @@ -278,8 +269,7 @@ def log_to_mlflow(metrics: list, params: list, tags: dict) -> None: def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None: - """ - Retrieve SageMaker metrics and hyperparameters and log them to MLflow. + """Retrieve SageMaker metrics and hyperparameters and log them to MLflow. Args: training_job_arn (str): The ARN of the SageMaker training job. diff --git a/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py index d39e37aa0a..4b53c93ad4 100644 --- a/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py +++ b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py @@ -11,12 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import absolute_import +from unittest.mock import patch, MagicMock, Mock +import json import pytest from mlflow.entities import Metric, Param -from unittest.mock import patch, MagicMock, Mock import requests -import json from sagemaker.mlflow.forward_sagemaker_metrics import ( encode, @@ -161,10 +162,9 @@ def test_batch_items(): assert batches == [[1, 2], [3, 4], [5]] -@patch("mlflow.MlflowClient") @patch("os.getenv") @patch("requests.Session.request") -def test_log_to_mlflow(mock_request, mock_getenv, mock_mlflow_client): +def test_log_to_mlflow(mock_request, mock_getenv): # Set up return values for os.getenv calls def getenv_side_effect(arg, default=None): values = {