Skip to content

Commit

Permalink
Fix pylint + doc check
Browse files Browse the repository at this point in the history
  • Loading branch information
ryansteakley committed Oct 2, 2024
1 parent 99b7633 commit eb2cc11
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 41 deletions.
64 changes: 27 additions & 37 deletions src/sagemaker/mlflow/forward_sagemaker_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow."""

from __future__ import absolute_import

import boto3
import os
import platform
import re
from typing import Set, Tuple, List, Dict, Generator
import boto3
import mlflow
from mlflow import MlflowClient
from mlflow.entities import Metric, Param, RunTag

from packaging import version
import platform
import re

from typing import Set, Tuple, List, Dict


def encode(name: str, existing_names: Set[str]) -> str:
"""
Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
"""Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
Args:
name (str): The original string to be encoded.
Expand Down Expand Up @@ -55,7 +56,8 @@ def encode_char(match):

if base_name in existing_names:
suffix = 1
# Edge case where even with suffix space there is a collision we will override one of the keys.
# Edge case where even with suffix space there is a collision
# we will override one of the keys.
while f"{base_name}_{suffix}" in existing_names:
suffix += 1
encoded = f"{base_name}_{suffix}"
Expand All @@ -68,9 +70,7 @@ def encode_char(match):


def decode(encoded_metric_name: str) -> str:

# TODO: Utilize the stored name mappings to get the original key mappings without having to decode.
"""Decodes an encoded metric name by replacing hexadecimal representations with their corresponding characters.
"""Decodes an encoded metric name by replacing hexadecimal representations with ASCII
This function reverses the encoding process by converting hexadecimal codes
back to their original characters. It looks for patterns of the form "_XX_"
Expand Down Expand Up @@ -100,8 +100,7 @@ def replace_code(match):


def get_training_job_details(job_arn: str) -> dict:
"""
Retrieve details of a SageMaker training job.
"""Retrieve details of a SageMaker training job.
Args:
job_arn (str): The ARN of the SageMaker training job.
Expand All @@ -118,8 +117,7 @@ def get_training_job_details(job_arn: str) -> dict:


def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
"""
Create metric queries for SageMaker metrics.
"""Create metric queries for SageMaker metrics.
Args:
job_arn (str): The ARN of the SageMaker training job.
Expand All @@ -142,8 +140,7 @@ def create_metric_queries(job_arn: str, metric_definitions: list) -> list:


def get_metric_data(metric_queries: list) -> dict:
"""
Retrieve metric data from SageMaker.
"""Retrieve metric data from SageMaker.
Args:
metric_queries (list): A list of metric queries.
Expand All @@ -162,8 +159,7 @@ def get_metric_data(metric_queries: list) -> dict:
def prepare_mlflow_metrics(
metric_queries: list, metric_results: list
) -> Tuple[List[Metric], Dict[str, str]]:
"""
Prepare metrics for MLflow logging, encoding metric names if necessary.
"""Prepare metrics for MLflow logging, encoding metric names if necessary.
Args:
metric_queries (list): The original metric queries sent to SageMaker.
Expand All @@ -184,29 +180,26 @@ def prepare_mlflow_metrics(
encoded_name = encode(metric_name, existing_names)
metric_name_mapping[encoded_name] = metric_name

mlflow_metrics.extend(
[
Metric(key=encoded_name, value=value, timestamp=timestamp, step=step)
for step, (timestamp, value) in enumerate(
zip(result["XAxisValues"], result["MetricValues"])
)
]
)
for step, (timestamp, value) in enumerate(
zip(result["XAxisValues"], result["MetricValues"])
):
metric = Metric(key=encoded_name, value=value, timestamp=timestamp, step=step)
mlflow_metrics.append(metric)

return mlflow_metrics, metric_name_mapping


def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]:
"""
Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
"""Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
Args:
hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job.
Returns:
Tuple[List[Param], Dict[str, str]]:
- A list of Param objects with encoded names (if necessary)
- A mapping of encoded to original names for hyperparameters (only for encoded parameters)
- A mapping of encoded to original names for
hyperparameters (only for encoded parameters)
"""
mlflow_params = []
param_name_mapping = {}
Expand All @@ -220,9 +213,8 @@ def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param],
return mlflow_params, param_name_mapping


def batch_items(items: list, batch_size: int) -> list:
"""
Yield successive batch_size chunks from items.
def batch_items(items: list, batch_size: int) -> Generator:
"""Yield successive batch_size chunks from items.
Args:
items (list): The list of items to be batched.
Expand All @@ -236,8 +228,7 @@ def batch_items(items: list, batch_size: int) -> list:


def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
"""
Log metrics, parameters, and tags to MLflow.
"""Log metrics, parameters, and tags to MLflow.
Args:
metrics (list): List of metrics to log.
Expand Down Expand Up @@ -278,8 +269,7 @@ def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:


def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None:
"""
Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
"""Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
Args:
training_job_arn (str): The ARN of the SageMaker training job.
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import absolute_import
from unittest.mock import patch, MagicMock, Mock
import json
import pytest
from mlflow.entities import Metric, Param
from unittest.mock import patch, MagicMock, Mock
import requests

import json

from sagemaker.mlflow.forward_sagemaker_metrics import (
encode,
Expand Down Expand Up @@ -161,10 +162,9 @@ def test_batch_items():
assert batches == [[1, 2], [3, 4], [5]]


@patch("mlflow.MlflowClient")
@patch("os.getenv")
@patch("requests.Session.request")
def test_log_to_mlflow(mock_request, mock_getenv, mock_mlflow_client):
def test_log_to_mlflow(mock_request, mock_getenv):
# Set up return values for os.getenv calls
def getenv_side_effect(arg, default=None):
values = {
Expand Down

0 comments on commit eb2cc11

Please sign in to comment.