Refactor Sagemaker orchestrator metadata handling
htahir1 committed Dec 20, 2024
1 parent abf4610 commit 67705c2
Showing 1 changed file with 99 additions and 43 deletions.
142 changes: 99 additions & 43 deletions src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
@@ -521,7 +521,7 @@ def prepare_or_run_pipeline(
                 f"Creating EventBridge rule with schedule expression: {schedule_expr}"
             )

-            # Create IAM policy and trust relationship for EventBridge
+            # Create IAM policy for EventBridge
             iam_client = session.boto_session.client("iam")
             role_name = self.config.execution_role.split("/")[
                 -1
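
Note on the scheduling path: the hunk above covers only the IAM setup; the rule itself is created through the EventBridge API. A hypothetical boto3 sketch of that step (not the commit's code; rule_name, schedule_expr and the ARNs are illustrative placeholders):

import boto3

events_client = boto3.client("events", region_name="us-east-1")

rule_name = "zenml-my-pipeline-schedule"        # illustrative placeholder
schedule_expr = "rate(1 hour)"                  # illustrative placeholder
pipeline_arn = "arn:aws:sagemaker:us-east-1:123456789012:pipeline/my-pipeline"
role_arn = "arn:aws:iam::123456789012:role/my-execution-role"

# Create (or update) the rule carrying the cron/rate expression.
events_client.put_rule(
    Name=rule_name,
    ScheduleExpression=schedule_expr,
    State="ENABLED",
)

# Point the rule at the SageMaker pipeline; the role passed here must allow
# EventBridge to start pipeline executions.
events_client.put_targets(
    Rule=rule_name,
    Targets=[
        {
            "Id": "zenml-sagemaker-pipeline-target",
            "Arn": pipeline_arn,
            "RoleArn": role_arn,
            "SageMakerPipelineParameters": {"PipelineParameterList": []},
        }
    ],
)
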
@@ -633,12 +633,17 @@ def prepare_or_run_pipeline(
                 else "one-time"
             )

-            yield from self.compute_schedule_metadata(
-                rule_name=rule_name,
-                schedule_expr=schedule_expr,
-                pipeline_name=orchestrator_run_name,
-                next_execution=next_execution,
-                schedule_type=schedule_type,
+            schedule_metadata = {
+                "rule_name": rule_name,
+                "schedule_type": schedule_type,
+                "schedule_expr": schedule_expr,
+                "pipeline_name": orchestrator_run_name,
+                "next_execution": next_execution,
+            }
+
+            yield from self.compute_metadata(
+                execution=schedule_metadata,
+                settings=settings,
             )
         else:
             # Execute the pipeline immediately if no schedule is specified
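
The schedule_expr and schedule_type values consumed above are derived from the schedule attached to the deployment. A minimal sketch, not taken from this diff, of how a cron expression or a fixed interval could map to EventBridge's format (the field names are assumptions modeled on ZenML's Schedule):

from datetime import timedelta
from typing import Optional, Tuple


def to_eventbridge_schedule(
    cron_expression: Optional[str],
    interval: Optional[timedelta],
) -> Tuple[Optional[str], str]:
    """Return (schedule_expr, schedule_type) for an EventBridge rule."""
    if cron_expression:
        # EventBridge expects cron expressions wrapped in cron(...).
        return f"cron({cron_expression})", "cron"
    if interval:
        minutes = max(1, int(interval.total_seconds() // 60))
        unit = "minute" if minutes == 1 else "minutes"
        return f"rate({minutes} {unit})", "rate"
    # No recurrence: the rule fires once at the requested start time.
    return None, "one-time"


print(to_eventbridge_schedule("0 12 * * ? *", None))      # ('cron(0 12 * * ? *)', 'cron')
print(to_eventbridge_schedule(None, timedelta(hours=1)))  # ('rate(60 minutes)', 'rate')
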
@@ -757,7 +762,7 @@ def compute_metadata(
         """Generate run metadata based on the generated Sagemaker Execution.

         Args:
-            execution: The corresponding _PipelineExecution object.
+            execution: The corresponding _PipelineExecution object or schedule metadata dict.
             settings: The Sagemaker orchestrator settings.

@@ -766,19 +771,40 @@
         # Metadata
         metadata: Dict[str, MetadataType] = {}

-        # Orchestrator Run ID
-        if run_id := self._compute_orchestrator_run_id(execution):
-            metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id
-
-        # URL to the Sagemaker's pipeline view
-        if orchestrator_url := self._compute_orchestrator_url(execution):
-            metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
-
-        # URL to the corresponding CloudWatch page
-        if logs_url := self._compute_orchestrator_logs_url(
-            execution, settings
-        ):
-            metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)
+        # Handle schedule metadata if execution is a dict
+        if isinstance(execution, dict):
+            metadata.update(
+                {
+                    "schedule_rule_name": execution["rule_name"],
+                    "schedule_type": execution["schedule_type"],
+                    "schedule_expression": execution["schedule_expr"],
+                    "pipeline_name": execution["pipeline_name"],
+                }
+            )
+
+            if next_execution := execution.get("next_execution"):
+                metadata["next_execution_time"] = next_execution.isoformat()
+
+            # Add orchestrator metadata using the same pattern as execution metadata
+            if orchestrator_url := self._compute_schedule_url(execution):
+                metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
+
+            if logs_url := self._compute_schedule_logs_url(
+                execution, settings
+            ):
+                metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)
+        else:
+            # Handle execution metadata
+            if run_id := self._compute_orchestrator_run_id(execution):
+                metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id
+
+            if orchestrator_url := self._compute_orchestrator_url(execution):
+                metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
+
+            if logs_url := self._compute_orchestrator_logs_url(
+                execution, settings
+            ):
+                metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)

         yield metadata
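
A stripped-down, self-contained sketch of the dispatch pattern this hunk introduces: the same compute_metadata generator now accepts either a live pipeline execution or a plain dict of schedule information (the types and field names below are simplified stand-ins, not ZenML's actual classes):

from typing import Any, Dict, Iterator, Union


class FakeExecution:
    """Stand-in for sagemaker's _PipelineExecution object."""

    arn = "arn:aws:sagemaker:us-east-1:123456789012:pipeline/demo/execution/abc123"


def compute_metadata(
    execution: Union[Dict[str, Any], FakeExecution],
) -> Iterator[Dict[str, Any]]:
    metadata: Dict[str, Any] = {}
    if isinstance(execution, dict):
        # Scheduled run: copy the schedule fields straight into the metadata.
        metadata["schedule_rule_name"] = execution["rule_name"]
        metadata["schedule_expression"] = execution["schedule_expr"]
    else:
        # Immediate run: derive the run ID from the execution object.
        metadata["orchestrator_run_id"] = execution.arn.split("/")[-1]
    yield metadata


print(next(compute_metadata({"rule_name": "r", "schedule_expr": "rate(1 hour)"})))
print(next(compute_metadata(FakeExecution())))
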

@@ -876,34 +902,64 @@ def _compute_orchestrator_run_id(
             )
             return None

-    def compute_schedule_metadata(
-        self,
-        rule_name: str,
-        schedule_expr: str,
-        pipeline_name: str,
-        next_execution: Optional[datetime],
-        schedule_type: str,
-    ) -> Dict[str, MetadataType]:
-        """Generate metadata for scheduled pipeline executions.
-
-        Args:
-            rule_name: The name of the EventBridge rule
-            schedule_expr: The schedule expression (cron or rate)
-            pipeline_name: Name of the SageMaker pipeline
-            next_execution: Next scheduled execution time
-            schedule_type: Type of schedule (cron/rate/one-time)
-
-        Returns:
-            A dictionary of metadata related to the schedule.
-        """
-        metadata: Dict[str, MetadataType] = {
-            "schedule_rule_name": rule_name,
-            "schedule_type": schedule_type,
-            "schedule_expression": schedule_expr,
-            "pipeline_name": pipeline_name,
-        }
-
-        if next_execution:
-            metadata["next_execution_time"] = next_execution.isoformat()
-
-        return metadata
+    @staticmethod
+    def _compute_schedule_url(schedule_info: Dict[str, Any]) -> Optional[str]:
+        """Generate the SageMaker Console URL for a scheduled pipeline.
+
+        Args:
+            schedule_info: Dictionary containing schedule information.
+
+        Returns:
+            The URL to the pipeline in the SageMaker console.
+        """
+        try:
+            # Get the Sagemaker session
+            session = boto3.Session(region_name=schedule_info["region"])
+            sagemaker_client = session.client("sagemaker")
+
+            # List the Studio domains and get the Studio Domain ID
+            domains_response = sagemaker_client.list_domains()
+            studio_domain_id = domains_response["Domains"][0]["DomainId"]
+
+            return (
+                f"https://studio-{studio_domain_id}.studio.{schedule_info['region']}."
+                f"sagemaker.aws/pipelines/view/{schedule_info['pipeline_name']}"
+            )
+        except Exception as e:
+            logger.warning(
+                f"There was an issue while extracting the pipeline url: {e}"
+            )
+            return None
+
+    @staticmethod
+    def _compute_schedule_logs_url(
+        schedule_info: Dict[str, Any],
+        settings: SagemakerOrchestratorSettings,
+    ) -> Optional[str]:
+        """Generate the CloudWatch URL for a scheduled pipeline.
+
+        Args:
+            schedule_info: Dictionary containing schedule information.
+            settings: The Sagemaker orchestrator settings.
+
+        Returns:
+            The URL to query the pipeline logs in CloudWatch.
+        """
+        try:
+            use_training_jobs = True
+            if settings.use_training_step is not None:
+                use_training_jobs = settings.use_training_step
+
+            job_type = "Training" if use_training_jobs else "Processing"
+
+            return (
+                f"https://{schedule_info['region']}.console.aws.amazon.com/"
+                f"cloudwatch/home?region={schedule_info['region']}#logsV2:"
+                f"log-groups/log-group/$252Faws$252Fsagemaker$252F{job_type}Jobs"
+                f"$3FlogStreamNameFilter$3Dpipelines-"
+            )
+        except Exception as e:
+            logger.warning(
+                f"There was an issue while extracting the logs url: {e}"
+            )
+            return None
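
For reference, the CloudWatch logs URL the new helper builds can be reproduced standalone; a minimal sketch assuming region us-east-1 and the training-job default (the $-escaped sequences mirror the helper's own format string):

region = "us-east-1"        # assumed example region
use_training_jobs = True    # mirrors the default when use_training_step is unset
job_type = "Training" if use_training_jobs else "Processing"

logs_url = (
    f"https://{region}.console.aws.amazon.com/"
    f"cloudwatch/home?region={region}#logsV2:"
    f"log-groups/log-group/$252Faws$252Fsagemaker$252F{job_type}Jobs"
    f"$3FlogStreamNameFilter$3Dpipelines-"
)
print(logs_url)
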
