Skip to content

Commit

Permalink
filling out pipeline classes and Dashboard pipeline subpage
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Dec 16, 2024
1 parent dcbb112 commit 2bfcf16
Show file tree
Hide file tree
Showing 11 changed files with 419 additions and 272 deletions.
23 changes: 13 additions & 10 deletions applications/aws_dashboard/pages/pipelines/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


# SageWorks Imports
from sageworks.api.pipeline import Pipeline
from sageworks.web_interface.page_views.pipelines_page_view import PipelinesPageView
from sageworks.web_interface.components.plugins.ag_table import AGTable
from sageworks.cached.cached_pipeline import CachedPipeline

# Get the SageWorks logger
log = logging.getLogger("sageworks")
Expand Down Expand Up @@ -46,17 +48,18 @@ def _on_page_load(href, row_data, page_already_loaded):
raise PreventUpdate


def pipeline_table_refresh(page_view: PipelinesPageView, table: AGTable):
    """Register the Dash callback that keeps the Pipelines table up to date.

    Args:
        page_view (PipelinesPageView): Page view that aggregates pipeline summaries.
        table (AGTable): AG-Grid table component whose properties are refreshed.
    """

    @callback(
        [Output(component_id, prop) for component_id, prop in table.properties],
        Input("pipelines_refresh", "n_intervals"),
    )
    def _pipeline_table_refresh(_n):
        """Return the table data for the Pipelines Table"""
        page_view.refresh()
        pipelines = page_view.pipelines()

        # AG-Grid row selection relies on these two columns being present
        # (callers read selected_row_data["Name"] downstream)
        pipelines["uuid"] = pipelines["Name"]
        pipelines["id"] = range(len(pipelines))
        return table.update_properties(pipelines)


# Set up the plugin callbacks that take a pipeline
Expand All @@ -73,10 +76,10 @@ def update_all_plugin_properties(selected_rows):

# Get the selected row data and grab the name
selected_row_data = selected_rows[0]
pipeline_name = selected_row_data["name"]
pipeline_name = selected_row_data["Name"]

# Create the Endpoint object
pipeline = Pipeline(pipeline_name)
pipeline = CachedPipeline(pipeline_name)

# Update all the properties for each plugin
all_props = []
Expand Down
12 changes: 8 additions & 4 deletions applications/aws_dashboard/pages/pipelines/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from . import callbacks

# SageWorks Imports
from sageworks.web_interface.components.plugins import ag_table, pipeline_details
from sageworks.web_interface.components.plugins import pipeline_details, ag_table
from sageworks.web_interface.components.plugin_interface import PluginPage
from sageworks.web_interface.page_views.pipelines_page_view import PipelinesPageView
from sageworks.utils.plugin_manager import PluginManager

# Register this page with Dash
Expand All @@ -28,7 +29,8 @@
details_component = pipeline_details.create_component("pipeline_details")

# Capture our components in a dictionary to send off to the layout
# Capture our components in a dictionary to send off to the layout
components = {
    "pipelines_table": table_component,
    "pipeline_details": details_component,
}

# Load any web components plugins of type 'pipeline'
pm = PluginManager()
Expand All @@ -42,12 +44,14 @@
# Set up our layout (Dash looks for a var called layout)
layout = pipelines_layout(**components)

# Grab a view that gives us a summary of the Pipelines in SageWorks
pipelines_view = PipelinesPageView()

# Callback for anything we want to happen on page load
callbacks.on_page_load()

# Setup our callbacks/connections
app = dash.get_app()
callbacks.pipeline_table_refresh(pipelines_view, pipeline_table)

# We're going to add the details component to the plugins list
plugins.append(pipeline_details)
Expand Down
70 changes: 36 additions & 34 deletions src/sageworks/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sageworks.utils.config_manager import ConfigManager
from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
from sageworks.core.pipelines.pipeline_executor import PipelineExecutor
from sageworks.api.parameter_store import ParameterStore


class Pipeline:
Expand All @@ -29,31 +30,36 @@ class Pipeline:
def __init__(self, name: str):
    """Pipeline Init Method

    Args:
        name (str): Name of this Pipeline (also used as its uuid)
    """
    self.log = logging.getLogger("sageworks")
    self.uuid = name

    # Spin up a Parameter Store for Pipelines
    self.prefix = "/sageworks/pipelines"
    self.params = ParameterStore()

    # NOTE(review): params.get() presumably returns None when no entry exists
    # for this uuid -- confirm downstream methods tolerate a None pipeline
    self.pipeline = self.params.get(f"{self.prefix}/{self.uuid}")

def summary(self, **kwargs) -> dict:
    """Retrieve the Pipeline Summary.

    Returns:
        dict: A dictionary of details about the Pipeline
    """
    # The stored pipeline definition doubles as its summary
    return self.pipeline

def details(self, **kwargs) -> dict:
    """Retrieve the Pipeline Details.

    Returns:
        dict: A dictionary of details about the Pipeline
    """
    # No extra detail is tracked beyond the stored pipeline definition
    return self.pipeline

def health_check(self, **kwargs) -> dict:
    """Retrieve the Pipeline Health Check.

    Returns:
        dict: A dictionary of health check details for the Pipeline
    """
    # Pipelines currently report no health issues (empty check result)
    return {}

def set_input(self, input: Union[str, pd.DataFrame], artifact: str = "data_source"):
"""Set the input for the Pipeline
Expand Down Expand Up @@ -105,7 +111,7 @@ def report_settable_fields(self, pipeline: dict = {}, path: str = "") -> None:
"""
# Grab the entire pipeline if not provided (first call)
if not pipeline:
self.log.important(f"Checking Pipeline: {self.name}...")
self.log.important(f"Checking Pipeline: {self.uuid}...")
pipeline = self.pipeline
for key, value in pipeline.items():
if isinstance(value, dict):
Expand All @@ -118,14 +124,8 @@ def report_settable_fields(self, pipeline: dict = {}, path: str = "") -> None:

def delete(self):
    """Pipeline Deletion

    Removes this Pipeline's entry from the Parameter Store.
    """
    self.log.info(f"Deleting Pipeline: {self.uuid}...")
    self.params.delete(f"{self.prefix}/{self.uuid}")

def __repr__(self) -> str:
"""String representation of this pipeline
Expand All @@ -145,10 +145,12 @@ def __repr__(self) -> str:
log = logging.getLogger("sageworks")

# Temp testing
"""
my_pipeline = Pipeline("aqsol_pipeline_v1")
my_pipeline.set_input("s3://sageworks-public-data/comp_chem/aqsol_public_data.csv")
my_pipeline.execute_partial(["model", "endpoint"])
exit(0)
"""

# Retrieve an existing Pipeline
my_pipeline = Pipeline("abalone_pipeline_v1")
Expand Down
160 changes: 0 additions & 160 deletions src/sageworks/api/pipeline_manager.py

This file was deleted.

Loading

0 comments on commit 2bfcf16

Please sign in to comment.