From 2bfcf16f32ab50480cf6c47cb7e9c1f5dd36d9dd Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Sun, 15 Dec 2024 20:01:02 -0700 Subject: [PATCH] filling out pipeline classes and Dashboard pipeline subpage --- .../pages/pipelines/callbacks.py | 23 +-- .../aws_dashboard/pages/pipelines/page.py | 12 +- src/sageworks/api/pipeline.py | 70 ++++---- src/sageworks/api/pipeline_manager.py | 160 ------------------ src/sageworks/cached/cached_pipeline.py | 66 ++++++++ .../core/cloud_platform/aws/aws_meta.py | 144 +++++++++++----- .../core/cloud_platform/cloud_meta.py | 2 +- src/sageworks/repl/sageworks_shell.py | 8 +- src/sageworks/utils/pipeline_utils.py | 72 ++++++++ .../components/plugins/pipeline_details.py | 54 ++++-- .../page_views/pipelines_page_view.py | 80 +++++++++ 11 files changed, 419 insertions(+), 272 deletions(-) delete mode 100644 src/sageworks/api/pipeline_manager.py create mode 100644 src/sageworks/cached/cached_pipeline.py create mode 100644 src/sageworks/utils/pipeline_utils.py create mode 100644 src/sageworks/web_interface/page_views/pipelines_page_view.py diff --git a/applications/aws_dashboard/pages/pipelines/callbacks.py b/applications/aws_dashboard/pages/pipelines/callbacks.py index fbbea9f56..d6558b2f5 100644 --- a/applications/aws_dashboard/pages/pipelines/callbacks.py +++ b/applications/aws_dashboard/pages/pipelines/callbacks.py @@ -9,7 +9,9 @@ # SageWorks Imports -from sageworks.api.pipeline import Pipeline +from sageworks.web_interface.page_views.pipelines_page_view import PipelinesPageView +from sageworks.web_interface.components.plugins.ag_table import AGTable +from sageworks.cached.cached_pipeline import CachedPipeline # Get the SageWorks logger log = logging.getLogger("sageworks") @@ -46,17 +48,18 @@ def _on_page_load(href, row_data, page_already_loaded): raise PreventUpdate -def update_pipelines_table(table_object): +def pipeline_table_refresh(page_view: PipelinesPageView, table: AGTable): @callback( - [Output(component_id, prop) for component_id, prop in table_object.properties], + [Output(component_id, prop) for component_id, prop in table.properties], Input("pipelines_refresh", "n_intervals"), ) - def pipelines_update(_n): + def _pipeline_table_refresh(_n): """Return the table data for the Pipelines Table""" - - # FIXME: This is a placeholder for the actual data - pipelines = pd.DataFrame({"name": ["Pipeline 1", "Pipeline 2", "Pipeline 3"]}) - return table_object.update_properties(pipelines) + page_view.refresh() + pipelines = page_view.pipelines() + pipelines["uuid"] = pipelines["Name"] + pipelines["id"] = range(len(pipelines)) + return table.update_properties(pipelines) # Set up the plugin callbacks that take a pipeline @@ -73,10 +76,10 @@ def update_all_plugin_properties(selected_rows): # Get the selected row data and grab the name selected_row_data = selected_rows[0] - pipeline_name = selected_row_data["name"] + pipeline_name = selected_row_data["Name"] # Create the Endpoint object - pipeline = Pipeline(pipeline_name) + pipeline = CachedPipeline(pipeline_name) # Update all the properties for each plugin all_props = [] diff --git a/applications/aws_dashboard/pages/pipelines/page.py b/applications/aws_dashboard/pages/pipelines/page.py index f3d5d3565..2413e59c8 100644 --- a/applications/aws_dashboard/pages/pipelines/page.py +++ b/applications/aws_dashboard/pages/pipelines/page.py @@ -8,8 +8,9 @@ from . import callbacks # SageWorks Imports -from sageworks.web_interface.components.plugins import ag_table, pipeline_details +from sageworks.web_interface.components.plugins import pipeline_details, ag_table from sageworks.web_interface.components.plugin_interface import PluginPage +from sageworks.web_interface.page_views.pipelines_page_view import PipelinesPageView from sageworks.utils.plugin_manager import PluginManager # Register this page with Dash @@ -28,7 +29,8 @@ details_component = pipeline_details.create_component("pipeline_details") # Capture our components in a dictionary to send off to the layout -components = {"pipelines_table": table_component, "pipeline_details": details_component} +components = {"pipelines_table": table_component, + "pipeline_details": details_component} # Load any web components plugins of type 'pipeline' pm = PluginManager() @@ -42,12 +44,14 @@ # Set up our layout (Dash looks for a var called layout) layout = pipelines_layout(**components) +# Grab a view that gives us a summary of the Pipelines in SageWorks +pipelines_view = PipelinesPageView() + # Callback for anything we want to happen on page load callbacks.on_page_load() # Setup our callbacks/connections -app = dash.get_app() -callbacks.update_pipelines_table(pipeline_table) +callbacks.pipeline_table_refresh(pipelines_view, pipeline_table) # We're going to add the details component to the plugins list plugins.append(pipeline_details) diff --git a/src/sageworks/api/pipeline.py b/src/sageworks/api/pipeline.py index 97df06062..5bdb3500b 100644 --- a/src/sageworks/api/pipeline.py +++ b/src/sageworks/api/pipeline.py @@ -11,6 +11,7 @@ from sageworks.utils.config_manager import ConfigManager from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp from sageworks.core.pipelines.pipeline_executor import PipelineExecutor +from sageworks.api.parameter_store import ParameterStore class Pipeline: @@ -29,31 +30,36 @@ class Pipeline: def __init__(self, name: str): """Pipeline Init Method""" self.log = logging.getLogger("sageworks") - self.name = name - - # Grab our SageWorks Bucket from Config - self.cm = ConfigManager() - self.sageworks_bucket = self.cm.get_config("SAGEWORKS_BUCKET") - if self.sageworks_bucket is None: - self.log = logging.getLogger("sageworks") - self.log.critical("Could not find ENV var for SAGEWORKS_BUCKET!") - sys.exit(1) - - # Set the S3 Path for this Pipeline - self.bucket = self.sageworks_bucket - self.key = f"pipelines/{self.name}.json" - self.s3_path = f"s3://{self.bucket}/{self.key}" - - # Grab a SageWorks Session (this allows us to assume the SageWorks ExecutionRole) - self.boto3_session = AWSAccountClamp().boto3_session - self.s3_client = self.boto3_session.client("s3") - - # If this S3 Path exists, load the Pipeline - if wr.s3.does_object_exist(self.s3_path): - self.pipeline = self._get_pipeline() - else: - self.log.warning(f"Pipeline {self.name} not found at {self.s3_path}") - self.pipeline = None + self.uuid = name + + # Spin up a Parameter Store for Pipelines + self.prefix = "/sageworks/pipelines" + self.params = ParameterStore() + self.pipeline = self.params.get(f"{self.prefix}/{self.uuid}") + + def summary(self, **kwargs) -> dict: + """Retrieve the Pipeline Summary. + + Returns: + dict: A dictionary of details about the Pipeline + """ + return self.pipeline + + def details(self, **kwargs) -> dict: + """Retrieve the Pipeline Details. + + Returns: + dict: A dictionary of details about the Pipeline + """ + return self.pipeline + + def health_check(self, **kwargs) -> dict: + """Retrieve the Pipeline Health Check. + + Returns: + dict: A dictionary of health check details for the Pipeline + """ + return {} def set_input(self, input: Union[str, pd.DataFrame], artifact: str = "data_source"): """Set the input for the Pipeline @@ -105,7 +111,7 @@ def report_settable_fields(self, pipeline: dict = {}, path: str = "") -> None: """ # Grab the entire pipeline if not provided (first call) if not pipeline: - self.log.important(f"Checking Pipeline: {self.name}...") + self.log.important(f"Checking Pipeline: {self.uuid}...") pipeline = self.pipeline for key, value in pipeline.items(): if isinstance(value, dict): @@ -118,14 +124,8 @@ def report_settable_fields(self, pipeline: dict = {}, path: str = "") -> None: def delete(self): """Pipeline Deletion""" - self.log.info(f"Deleting Pipeline: {self.name}...") - wr.s3.delete_objects(self.s3_path) - - def _get_pipeline(self) -> dict: - """Internal: Get the pipeline as a JSON object from the specified S3 bucket and key.""" - response = self.s3_client.get_object(Bucket=self.bucket, Key=self.key) - json_object = json.loads(response["Body"].read()) - return json_object + self.log.info(f"Deleting Pipeline: {self.uuid}...") + self.params.delete(f"{self.prefix}/{self.uuid}") def __repr__(self) -> str: """String representation of this pipeline @@ -145,10 +145,12 @@ def __repr__(self) -> str: log = logging.getLogger("sageworks") # Temp testing + """ my_pipeline = Pipeline("aqsol_pipeline_v1") my_pipeline.set_input("s3://sageworks-public-data/comp_chem/aqsol_public_data.csv") my_pipeline.execute_partial(["model", "endpoint"]) exit(0) + """ # Retrieve an existing Pipeline my_pipeline = Pipeline("abalone_pipeline_v1") diff --git a/src/sageworks/api/pipeline_manager.py b/src/sageworks/api/pipeline_manager.py deleted file mode 100644 index 570f6309b..000000000 --- a/src/sageworks/api/pipeline_manager.py +++ /dev/null @@ -1,160 +0,0 @@ -"""PipelineManager: Manages SageWorks Pipelines, listing, creating, and saving them.""" - -import logging -import json - -# SageWorks Imports -from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp -from sageworks.api import DataSource, FeatureSet, Model, Endpoint, ParameterStore - - -class PipelineManager: - """PipelineManager: Manages SageWorks Pipelines, listing, creating, and saving them. - - Common Usage: - ```python - my_manager = PipelineManager() - my_manager.list_pipelines() - abalone_pipeline = my_manager.create_from_endpoint("abalone-regression-end") - my_manager.save_pipeline("abalone_pipeline_v1", abalone_pipeline) - ``` - """ - - def __init__(self): - """Pipeline Init Method""" - self.log = logging.getLogger("sageworks") - - # We use the ParameterStore for storage of pipelines - self.prefix = "/sageworks/pipelines/" - self.param_store = ParameterStore() - - # Grab a SageWorks Session (this allows us to assume the SageWorks ExecutionRole) - self.boto3_session = AWSAccountClamp().boto3_session - - def list_pipelines(self) -> list: - """List all the Pipelines in the S3 Bucket - - Returns: - list: A list of Pipeline names and details - """ - # List pipelines stored in the parameter store - return self.param_store.list(self.prefix) - - # Create a new Pipeline from an Endpoint - def create_from_endpoint(self, endpoint_name: str) -> dict: - """Create a Pipeline from an Endpoint - - Args: - endpoint_name (str): The name of the Endpoint - - Returns: - dict: A dictionary of the Pipeline - """ - self.log.important(f"Creating Pipeline from Endpoint: {endpoint_name}...") - pipeline = {} - endpoint = Endpoint(endpoint_name) - model = Model(endpoint.get_input()) - feature_set = FeatureSet(model.get_input()) - data_source = DataSource(feature_set.get_input()) - s3_source = data_source.get_input() - for name in ["data_source", "feature_set", "model", "endpoint"]: - artifact = locals()[name] - pipeline[name] = {"name": artifact.uuid, "tags": artifact.get_tags(), "input": artifact.get_input()} - if name == "model": - pipeline[name]["model_type"] = artifact.model_type.value - pipeline[name]["target_column"] = artifact.target() - pipeline[name]["feature_list"] = artifact.features() - - # Return the Pipeline - return pipeline - - def publish_pipeline(self, name: str, pipeline: dict): - """Publish a Pipeline to Parameter Store - - Args: - name (str): The name of the Pipeline - pipeline (dict): The Pipeline to save - """ - key = f"{self.prefix}/{name}" - self.log.important(f"Saving {name} to Parameter Store {self.prefix}...") - - # Save the pipeline to the parameter store - self.param_store.upsert(key, json.dumps(pipeline)) - - def delete_pipeline(self, name: str): - """Delete a Pipeline from S3 - - Args: - name (str): The name of the Pipeline to delete - """ - key = f"{self.prefix}{name}.json" - self.log.important(f"Deleting {name} from S3: {self.bucket}/{key}...") - - # Delete the pipeline object from S3 - self.s3_client.delete_object(Bucket=self.bucket, Key=key) - - # Save a Pipeline to a local file - def save_pipeline_to_file(self, pipeline: dict, filepath: str): - """Save a Pipeline to a local file - - Args: - pipeline (dict): The Pipeline to save - filepath (str): The path to save the Pipeline - """ - - # Sanity check the filepath - if not filepath.endswith(".json"): - filepath += ".json" - - # Save the pipeline as a local JSON file - with open(filepath, "w") as fp: - json.dump(pipeline, fp, indent=4) - - def load_pipeline_from_file(self, filepath: str) -> dict: - """Load a Pipeline from a local file - - Args: - filepath (str): The path of the Pipeline to load - - Returns: - dict: The Pipeline loaded from the file - """ - - # Load a pipeline as a local JSON file - with open(filepath, "r") as fp: - pipeline = json.load(fp) - return pipeline - - def publish_pipeline_from_file(self, filepath: str): - """Publish a Pipeline to SageWorks from a local file - - Args: - filepath (str): The path of the Pipeline to publish - """ - - # Load a pipeline as a local JSON file - pipeline = self.load_pipeline_from_file(filepath) - - # Get the pipeline name - pipeline_name = filepath.split("/")[-1].replace(".json", "") - - # Publish the Pipeline - self.publish_pipeline(pipeline_name, pipeline) - - -if __name__ == "__main__": - """Exercise the Pipeline Class""" - from pprint import pprint - - # Create a PipelineManager - my_manager = PipelineManager() - - # List the Pipelines - print("Listing Pipelines...") - pprint(my_manager.list_pipelines()) - - # Create a Pipeline from an Endpoint - abalone_pipeline = my_manager.create_from_endpoint("abalone-regression-end") - - # Publish the Pipeline - my_manager.publish_pipeline("abalone_pipeline_v1", abalone_pipeline) diff --git a/src/sageworks/cached/cached_pipeline.py b/src/sageworks/cached/cached_pipeline.py new file mode 100644 index 000000000..d549c1d22 --- /dev/null +++ b/src/sageworks/cached/cached_pipeline.py @@ -0,0 +1,66 @@ +"""CachedPipeline: Caches the method results for SageWorks Pipelines""" + +from typing import Union + +# SageWorks Imports +from sageworks.api.pipeline import Pipeline +from sageworks.core.artifacts.cached_artifact_mixin import CachedArtifactMixin + + +class CachedPipeline(CachedArtifactMixin, Pipeline): + """CachedPipeline: Caches the method results for SageWorks Pipelines + + Note: Cached method values may lag underlying Pipeline changes. + + Common Usage: + ```python + my_pipeline = CachedPipeline(name) + my_pipeline.details() + my_pipeline.health_check() + ``` + """ + + def __init__(self, pipeline_uuid: str): + """CachedPipeline Initialization""" + Pipeline.__init__(self, name=pipeline_uuid) + + @CachedArtifactMixin.cache_result + def summary(self, **kwargs) -> dict: + """Retrieve the CachedPipeline Details. + + Returns: + dict: A dictionary of details about the CachedPipeline + """ + return super().summary(**kwargs) + + @CachedArtifactMixin.cache_result + def details(self, **kwargs) -> dict: + """Retrieve the CachedPipeline Details. + + Returns: + dict: A dictionary of details about the CachedPipeline + """ + return super().details(**kwargs) + + @CachedArtifactMixin.cache_result + def health_check(self, **kwargs) -> dict: + """Retrieve the CachedPipeline Health Check. + + Returns: + dict: A dictionary of health check details for the CachedPipeline + """ + return super().health_check(**kwargs) + + +if __name__ == "__main__": + """Exercise the CachedPipeline Class""" + from pprint import pprint + + # Retrieve an existing Pipeline + my_pipeline = CachedPipeline("abalone_pipeline_v1") + pprint(my_pipeline.summary()) + pprint(my_pipeline.details()) + pprint(my_pipeline.health_check()) + + # Shutdown the ThreadPoolExecutor (note: users should NOT call this) + my_pipeline._shutdown() diff --git a/src/sageworks/core/cloud_platform/aws/aws_meta.py b/src/sageworks/core/cloud_platform/aws/aws_meta.py index e9d631c1b..f25240c07 100644 --- a/src/sageworks/core/cloud_platform/aws/aws_meta.py +++ b/src/sageworks/core/cloud_platform/aws/aws_meta.py @@ -8,12 +8,14 @@ import pandas as pd import awswrangler as wr from collections import defaultdict +from datetime import datetime, timezone # SageWorks Imports from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp from sageworks.utils.config_manager import ConfigManager from sageworks.utils.datetime_utils import datetime_string from sageworks.utils.aws_utils import not_found_returns_none, aws_throttle, aws_tags_to_dict +from sageworks.api.parameter_store import ParameterStore class AWSMeta: @@ -31,6 +33,10 @@ def __init__(self): self.account_clamp = AWSAccountClamp() self.cm = ConfigManager() + # Parameter Store for Pipelines + self.pipeline_prefix = "/sageworks/pipelines" + self.param_store = ParameterStore() + # Storing the size of various metadata for tracking self.metadata_sizes = defaultdict(dict) @@ -305,55 +311,32 @@ def endpoints(self, refresh: bool = False) -> pd.DataFrame: # Return the summary as a DataFrame return pd.DataFrame(data_summary).convert_dtypes() - def aws_pipelines(self) -> pd.DataFrame: - """Get a summary of the Pipelines deployed in the Cloud Platform. + def pipelines(self) -> pd.DataFrame: + """List all the Pipelines in the S3 Bucket Returns: - pd.DataFrame: A summary of the Pipelines in the Cloud Platform. + pd.DataFrame: A dataframe of Pipelines information """ - import pandas as pd - - # Initialize the SageMaker client and list all pipelines - sagemaker_client = self.boto3_session.client("sagemaker") - data_summary = [] - - # List all pipelines - pipelines = sagemaker_client.list_pipelines()["PipelineSummaries"] - - # Loop through each pipeline to get its executions - for pipeline in pipelines: - pipeline_name = pipeline["PipelineName"] + # List pipelines stored in the parameter store + pipeline_summaries = [] + pipeline_list = self.param_store.list(self.pipeline_prefix) + for pipeline_name in pipeline_list: + pipeline_info = self.param_store.get(pipeline_name) - # Use paginator to retrieve all executions for this pipeline - paginator = sagemaker_client.get_paginator("list_pipeline_executions") - for page in paginator.paginate(PipelineName=pipeline_name): - for execution in page["PipelineExecutionSummaries"]: - pipeline_execution_arn = execution["PipelineExecutionArn"] - - # Get detailed information about the pipeline execution - pipeline_info = sagemaker_client.describe_pipeline_execution( - PipelineExecutionArn=pipeline_execution_arn - ) - - # Retrieve SageWorks metadata from tags - sageworks_meta = self.get_aws_tags(pipeline_execution_arn) - health_tags = sageworks_meta.get("sageworks_health_tags", "") - - # Compile pipeline summary - summary = { - "Name": pipeline_name, - "ExecutionName": execution["PipelineExecutionDisplayName"], - "Health": health_tags, - "Created": datetime_string(pipeline_info.get("CreationTime")), - "Tags": sageworks_meta.get("sageworks_tags", "-"), - "Input": sageworks_meta.get("sageworks_input", "-"), - "Status": pipeline_info["PipelineExecutionStatus"], - "PipelineArn": pipeline_execution_arn, - } - data_summary.append(summary) + # Compile pipeline summary + summary = { + "Name": pipeline_name.replace(self.pipeline_prefix + "/", ""), + "Health": "", + "Num Stages": len(pipeline_info), + "Tags": pipeline_info.get("tags", "-"), + "Modified": datetime_string(datetime.now(timezone.utc)), + "Last Run": datetime_string(datetime.now(timezone.utc)), + "Status": "Success", # pipeline_info.get("Status", "-"), + } + pipeline_summaries.append(summary) # Return the summary as a DataFrame - return pd.DataFrame(data_summary).convert_dtypes() + return pd.DataFrame(pipeline_summaries).convert_dtypes() @not_found_returns_none def glue_job(self, job_name: str) -> Union[dict, None]: @@ -487,6 +470,18 @@ def endpoint(self, endpoint_name: str) -> Union[dict, None]: endpoint_details["sageworks_meta"] = self.get_aws_tags(endpoint_details["EndpointArn"]) return endpoint_details + @not_found_returns_none + def pipeline(self, pipeline_name: str) -> Union[dict, None]: + """Describe a single SageWorks Pipeline. + + Args: + pipeline_name (str): The name of the pipeline to describe. + + Returns: + dict: A detailed description of the pipeline (None if not found). + """ + return self.param_store.get(f"{self.pipeline_prefix}/{pipeline_name}") + # These are helper methods to construct the AWS URL for the Artifacts @staticmethod def s3_to_console_url(s3_path: str) -> str: @@ -625,6 +620,56 @@ def _list_catalog_tables(self, database: str, views: bool = False) -> pd.DataFra return pd.DataFrame(data_summary).convert_dtypes() + def _aws_pipelines(self) -> pd.DataFrame: + """Internal: Get a summary of the Cloud internal Pipelines (not SageWorks Pipelines). + + Returns: + pd.DataFrame: A summary of the Cloud internal Pipelines (not SageWorks Pipelines). + """ + import pandas as pd + + # Initialize the SageMaker client and list all pipelines + sagemaker_client = self.boto3_session.client("sagemaker") + data_summary = [] + + # List all pipelines + pipelines = sagemaker_client.list_pipelines()["PipelineSummaries"] + + # Loop through each pipeline to get its executions + for pipeline in pipelines: + pipeline_name = pipeline["PipelineName"] + + # Use paginator to retrieve all executions for this pipeline + paginator = sagemaker_client.get_paginator("list_pipeline_executions") + for page in paginator.paginate(PipelineName=pipeline_name): + for execution in page["PipelineExecutionSummaries"]: + pipeline_execution_arn = execution["PipelineExecutionArn"] + + # Get detailed information about the pipeline execution + pipeline_info = sagemaker_client.describe_pipeline_execution( + PipelineExecutionArn=pipeline_execution_arn + ) + + # Retrieve SageWorks metadata from tags + sageworks_meta = self.get_aws_tags(pipeline_execution_arn) + health_tags = sageworks_meta.get("sageworks_health_tags", "") + + # Compile pipeline summary + summary = { + "Name": pipeline_name, + "ExecutionName": execution["PipelineExecutionDisplayName"], + "Health": health_tags, + "Created": datetime_string(pipeline_info.get("CreationTime")), + "Tags": sageworks_meta.get("sageworks_tags", "-"), + "Input": sageworks_meta.get("sageworks_input", "-"), + "Status": pipeline_info["PipelineExecutionStatus"], + "PipelineArn": pipeline_execution_arn, + } + data_summary.append(summary) + + # Return the summary as a DataFrame + return pd.DataFrame(data_summary).convert_dtypes() + def close(self): """Close the AWSMeta Class""" self.log.debug("Closing the AWSMeta Class") @@ -647,6 +692,8 @@ def __repr__(self): # Test the __repr__ method print(meta) + """ + # Get the AWS Account Info print("*** AWS Account ***") pprint(meta.account()) @@ -676,7 +723,6 @@ def __repr__(self): fs_views = meta.views("sagemaker_featurestore") print(fs_views) - """ # Get the Feature Sets print("\n\n*** Feature Sets ***") pprint(meta.feature_sets()) @@ -698,9 +744,13 @@ def __repr__(self): pprint(meta.endpoints()) """ - # Get the Pipelines - print("\n\n*** AWS Pipelines ***") - pprint(meta.aws_pipelines()) + # List Pipelines + print("\n\n*** SageWorks Pipelines ***") + pprint(meta.pipelines()) + + # Get one pipeline + print("\n\n*** Pipeline details ***") + pprint(meta.pipeline("abalone_pipeline_v1")) # Test out the specific artifact details methods print("\n\n*** Glue Job Details ***") diff --git a/src/sageworks/core/cloud_platform/cloud_meta.py b/src/sageworks/core/cloud_platform/cloud_meta.py index 65517eedf..560eb3c3d 100644 --- a/src/sageworks/core/cloud_platform/cloud_meta.py +++ b/src/sageworks/core/cloud_platform/cloud_meta.py @@ -135,7 +135,7 @@ def pipelines(self) -> pd.DataFrame: Returns: pd.DataFrame: A summary of the Pipelines in the Cloud Platform """ - return super().aws_pipelines() + return super().pipelines() def glue_job(self, job_name: str) -> Union[dict, None]: """Get the details of a specific Glue Job diff --git a/src/sageworks/repl/sageworks_shell.py b/src/sageworks/repl/sageworks_shell.py index 11aa3faea..b4e0eba7e 100644 --- a/src/sageworks/repl/sageworks_shell.py +++ b/src/sageworks/repl/sageworks_shell.py @@ -261,10 +261,7 @@ def import_sageworks(self): ).ComputationView self.commands["MDQView"] = importlib.import_module("sageworks.core.views.mdq_view").MDQView self.commands["PandasToView"] = importlib.import_module("sageworks.core.views.pandas_to_view").PandasToView - - # We're going to include these classes/imports later - # self.commands["Pipeline"] = importlib.import_module("sageworks.api.pipeline").Pipeline - # self.commands["PipelineManager"] = tbd + self.commands["Pipeline"] = importlib.import_module("sageworks.api.pipeline").Pipeline # These are 'nice to have' imports self.commands["pd"] = importlib.import_module("pandas") @@ -370,8 +367,7 @@ def endpoints(self): return self.meta.endpoints() def pipelines(self): - logging.error("Pipelines are not yet supported in the SageWorks REPL") - return pd.DataFrame() + return self.meta.pipelines() @staticmethod def log_debug(): diff --git a/src/sageworks/utils/pipeline_utils.py b/src/sageworks/utils/pipeline_utils.py new file mode 100644 index 000000000..a05418d3b --- /dev/null +++ b/src/sageworks/utils/pipeline_utils.py @@ -0,0 +1,72 @@ +"""SageWorks Pipeline Utilities""" + +import logging +import json + +# SageWorks Imports +from sageworks.api import DataSource, FeatureSet, Model, Endpoint, ParameterStore + + +# Set up the logging +log = logging.getLogger("sageworks") + + +# Create a new Pipeline from an Endpoint +def create_from_endpoint(endpoint_name: str) -> dict: + """Create a Pipeline from an Endpoint + + Args: + endpoint_name (str): The name of the Endpoint + + Returns: + dict: A dictionary of the Pipeline + """ + log.important(f"Creating Pipeline from Endpoint: {endpoint_name}...") + pipeline = {} + endpoint = Endpoint(endpoint_name) + model = Model(endpoint.get_input()) + feature_set = FeatureSet(model.get_input()) + data_source = DataSource(feature_set.get_input()) + s3_source = data_source.get_input() + for name in ["data_source", "feature_set", "model", "endpoint"]: + artifact = locals()[name] + pipeline[name] = {"name": artifact.uuid, "tags": artifact.get_tags(), "input": artifact.get_input()} + if name == "model": + pipeline[name]["model_type"] = artifact.model_type.value + pipeline[name]["target_column"] = artifact.target() + pipeline[name]["feature_list"] = artifact.features() + + # Return the Pipeline + return pipeline + + +def publish_pipeline(name: str, pipeline: dict): + """Publish a Pipeline to Parameter Store + + Args: + name (str): The name of the Pipeline + pipeline (dict): The Pipeline to save + """ + params = ParameterStore() + key = f"/sageworks/pipelines/{name}" + log.important(f"Saving {name} to Parameter Store {key}...") + + # Save the pipeline to the parameter store + params.upsert(key, json.dumps(pipeline)) + + +if __name__ == "__main__": + """Exercise the Pipeline Class""" + from pprint import pprint + from sageworks.api.meta import Meta + + # List the Pipelines + meta = Meta() + print("Listing Pipelines...") + pprint(meta.pipelines()) + + # Create a Pipeline from an Endpoint + abalone_pipeline = create_from_endpoint("abalone-regression-end") + + # Publish the Pipeline + publish_pipeline("abalone_pipeline_test", abalone_pipeline) diff --git a/src/sageworks/web_interface/components/plugins/pipeline_details.py b/src/sageworks/web_interface/components/plugins/pipeline_details.py index 8de9b3e48..eaa8e5ae2 100644 --- a/src/sageworks/web_interface/components/plugins/pipeline_details.py +++ b/src/sageworks/web_interface/components/plugins/pipeline_details.py @@ -7,6 +7,7 @@ # SageWorks Imports from sageworks.api.pipeline import Pipeline +from sageworks.utils.markdown_utils import health_tag_markdown from sageworks.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType # Get the SageWorks logger @@ -20,6 +21,14 @@ class PipelineDetails(PluginInterface): auto_load_page = PluginPage.NONE plugin_input_type = PluginInputType.PIPELINE + def __init__(self): + """Initialize the PipelineDetails plugin class""" + self.component_id = None + self.current_pipeline = None + + # Call the parent class constructor + super().__init__() + def create_component(self, component_id: str) -> html.Div: """Create a Markdown Component without any data. Args: @@ -27,18 +36,19 @@ def create_component(self, component_id: str) -> html.Div: Returns: html.Div: A Container of Components for the Model Details """ + self.component_id = component_id container = html.Div( - id=component_id, + id=self.component_id, children=[ - html.H3(id=f"{component_id}-header", children="Pipeline: Loading..."), - dcc.Markdown(id=f"{component_id}-details"), + html.H3(id=f"{self.component_id}-header", children="Pipeline: Loading..."), + dcc.Markdown(id=f"{self.component_id}-details"), ], ) # Fill in plugin properties self.properties = [ - (f"{component_id}-header", "children"), - (f"{component_id}-details", "children"), + (f"{self.component_id}-header", "children"), + (f"{self.component_id}-details", "children"), ] # Return the container @@ -54,17 +64,41 @@ def update_properties(self, pipeline: Pipeline, **kwargs) -> list: Returns: list: A list of the updated property values for the plugin """ - log.important(f"Updating Plugin with Pipeline: {pipeline.name} and kwargs: {kwargs}") + log.important(f"Updating Plugin with Pipeline: {pipeline.uuid} and kwargs: {kwargs}") # Update the header and the details - header = f"{pipeline.name}" - # pipeline_data = pipeline.get_pipeline_data() - details = "**Details:**\n" - details += f"**Name:** {pipeline.name}\n" + self.current_pipeline = pipeline + header = f"{self.current_pipeline.uuid}" + details = self.pipeline_details() # Return the updated property values for the plugin return [header, details] + def pipeline_details(self): + """Construct the markdown string for the pipeline details + + Returns: + str: A markdown string + """ + + # Construct the markdown string + details = self.current_pipeline.details() + markdown = "" + for key, value in details.items(): + + # If the value is a list, convert it to a comma-separated string + if isinstance(value, list): + value = ", ".join(value) + + # If the value is a dictionary, get the name + if isinstance(value, dict): + value = value.get("name", "Unknown") + + # Add to markdown string + markdown += f"**{key}:** {value} \n" + + return markdown + if __name__ == "__main__": # This class takes in pipeline details and generates a details Markdown component diff --git a/src/sageworks/web_interface/page_views/pipelines_page_view.py b/src/sageworks/web_interface/page_views/pipelines_page_view.py new file mode 100644 index 000000000..295605786 --- /dev/null +++ b/src/sageworks/web_interface/page_views/pipelines_page_view.py @@ -0,0 +1,80 @@ +"""PipelinesPageView pulls Pipeline metadata from the AWS Service Broker with Details Panels on each Pipeline""" + +import pandas as pd + +# SageWorks Imports +from sageworks.web_interface.page_views.page_view import PageView +from sageworks.cached.cached_meta import CachedMeta +from sageworks.cached.cached_pipeline import CachedPipeline +from sageworks.utils.symbols import tag_symbols + + +class PipelinesPageView(PageView): + def __init__(self): + """PipelinesPageView pulls Pipeline metadata and populates a Details Panel""" + # Call SuperClass Initialization + super().__init__() + + # CachedMeta object for Cloud Platform Metadata + self.meta = CachedMeta() + + # Initialize the Pipelines DataFrame + self.pipelines_df = None + self.refresh() + + def refresh(self): + """Refresh the pipeline data from the Cloud Platform""" + self.log.important("Calling refresh()..") + self.pipelines_df = self.meta.pipelines() + + # Drop the AWS URL column + self.pipelines_df.drop(columns=["_aws_url"], inplace=True, errors="ignore") + # Add Health Symbols to the Model Group Name + if "Health" in self.pipelines_df.columns: + self.pipelines_df["Health"] = self.pipelines_df["Health"].map(lambda x: tag_symbols(x)) + + def pipelines(self) -> pd.DataFrame: + """Get all the data that's useful for this view + + Returns: + pd.DataFrame: DataFrame of the Pipelines View Data + """ + return self.pipelines_df + + @staticmethod + def pipeline_details(pipeline_uuid: str) -> (dict, None): + """Get all the details for the given Pipeline UUID + Args: + pipeline_uuid(str): The UUID of the Pipeline + Returns: + dict: The details for the given Model (or None if not found) + """ + pipeline = CachedPipeline(pipeline_uuid) + if pipeline is None: + return {"Status": "Not Found"} + + # Return the Pipeline Details + return pipeline.details() + + +if __name__ == "__main__": + # Exercising the PipelinesPageView + import time + from pprint import pprint + + # Create the class and get the AWS Pipeline details + pipeline_view = PipelinesPageView() + + # List the Pipelines + print("PipelinesSummary:") + summary = pipeline_view.pipelines() + print(summary.head()) + + # Get the details for the first Pipeline + my_pipeline_uuid = summary["Name"].iloc[0] + print("\nPipelineDetails:") + details = pipeline_view.pipeline_details(my_pipeline_uuid) + pprint(details) + + # Give any broker threads time to finish + time.sleep(1)