diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index 72bfa14c9e..1abae0806d 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -521,7 +521,7 @@ def prepare_or_run_pipeline( f"Creating EventBridge rule with schedule expression: {schedule_expr}" ) - # Create IAM policy and trust relationship for EventBridge + # Create IAM policy for EventBridge iam_client = session.boto_session.client("iam") role_name = self.config.execution_role.split("/")[ -1 @@ -633,12 +633,17 @@ def prepare_or_run_pipeline( else "one-time" ) - yield from self.compute_schedule_metadata( - rule_name=rule_name, - schedule_expr=schedule_expr, - pipeline_name=orchestrator_run_name, - next_execution=next_execution, - schedule_type=schedule_type, + schedule_metadata = { + "rule_name": rule_name, + "schedule_type": schedule_type, + "schedule_expr": schedule_expr, + "pipeline_name": orchestrator_run_name, + "next_execution": next_execution, + } + + yield from self.compute_metadata( + execution=schedule_metadata, + settings=settings, ) else: # Execute the pipeline immediately if no schedule is specified @@ -757,7 +762,7 @@ def compute_metadata( """Generate run metadata based on the generated Sagemaker Execution. Args: - execution: The corresponding _PipelineExecution object. + execution: The corresponding _PipelineExecution object or schedule metadata dict. settings: The Sagemaker orchestrator settings. Yields: @@ -766,19 +771,40 @@ def compute_metadata( # Metadata metadata: Dict[str, MetadataType] = {} - # Orchestrator Run ID - if run_id := self._compute_orchestrator_run_id(execution): - metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id + # Handle schedule metadata if execution is a dict + if isinstance(execution, dict): + metadata.update( + { + "schedule_rule_name": execution["rule_name"], + "schedule_type": execution["schedule_type"], + "schedule_expression": execution["schedule_expr"], + "pipeline_name": execution["pipeline_name"], + } + ) + + if next_execution := execution.get("next_execution"): + metadata["next_execution_time"] = next_execution.isoformat() + + # Add orchestrator metadata using the same pattern as execution metadata + if orchestrator_url := self._compute_schedule_url(execution): + metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url) - # URL to the Sagemaker's pipeline view - if orchestrator_url := self._compute_orchestrator_url(execution): - metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url) + if logs_url := self._compute_schedule_logs_url( + execution, settings + ): + metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url) + else: + # Handle execution metadata + if run_id := self._compute_orchestrator_run_id(execution): + metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id + + if orchestrator_url := self._compute_orchestrator_url(execution): + metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url) - # URL to the corresponding CloudWatch page - if logs_url := self._compute_orchestrator_logs_url( - execution, settings - ): - metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url) + if logs_url := self._compute_orchestrator_logs_url( + execution, settings + ): + metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url) yield metadata @@ -876,34 +902,64 @@ def _compute_orchestrator_run_id( ) return None - def compute_schedule_metadata( - self, - rule_name: str, - schedule_expr: str, - pipeline_name: str, - next_execution: Optional[datetime], - schedule_type: str, - ) -> Dict[str, MetadataType]: - """Generate metadata for scheduled pipeline executions. + @staticmethod + def _compute_schedule_url(schedule_info: Dict[str, Any]) -> Optional[str]: + """Generate the SageMaker Console URL for a scheduled pipeline. Args: - rule_name: The name of the EventBridge rule - schedule_expr: The schedule expression (cron or rate) - pipeline_name: Name of the SageMaker pipeline - next_execution: Next scheduled execution time - schedule_type: Type of schedule (cron/rate/one-time) + schedule_info: Dictionary containing schedule information. Returns: - A dictionary of metadata related to the schedule. + The URL to the pipeline in the SageMaker console. """ - metadata: Dict[str, MetadataType] = { - "schedule_rule_name": rule_name, - "schedule_type": schedule_type, - "schedule_expression": schedule_expr, - "pipeline_name": pipeline_name, - } + try: + # Get the Sagemaker session + session = boto3.Session(region_name=schedule_info["region"]) + sagemaker_client = session.client("sagemaker") + + # List the Studio domains and get the Studio Domain ID + domains_response = sagemaker_client.list_domains() + studio_domain_id = domains_response["Domains"][0]["DomainId"] + + return ( + f"https://studio-{studio_domain_id}.studio.{schedule_info['region']}." + f"sagemaker.aws/pipelines/view/{schedule_info['pipeline_name']}" + ) + except Exception as e: + logger.warning( + f"There was an issue while extracting the pipeline url: {e}" + ) + return None + + @staticmethod + def _compute_schedule_logs_url( + schedule_info: Dict[str, Any], + settings: SagemakerOrchestratorSettings, + ) -> Optional[str]: + """Generate the CloudWatch URL for a scheduled pipeline. - if next_execution: - metadata["next_execution_time"] = next_execution.isoformat() + Args: + schedule_info: Dictionary containing schedule information. + settings: The Sagemaker orchestrator settings. - return metadata + Returns: + The URL to query the pipeline logs in CloudWatch. + """ + try: + use_training_jobs = True + if settings.use_training_step is not None: + use_training_jobs = settings.use_training_step + + job_type = "Training" if use_training_jobs else "Processing" + + return ( + f"https://{schedule_info['region']}.console.aws.amazon.com/" + f"cloudwatch/home?region={schedule_info['region']}#logsV2:" + f"log-groups/log-group/$252Faws$252Fsagemaker$252F{job_type}Jobs" + f"$3FlogStreamNameFilter$3Dpipelines-" + ) + except Exception as e: + logger.warning( + f"There was an issue while extracting the logs url: {e}" + ) + return None