Skip to content

Commit

Permalink
Merge pull request #447 from SuperCowPowers/signals_and_slots
Browse files Browse the repository at this point in the history
Signals and slots
  • Loading branch information
brifordwylie authored Apr 10, 2024
2 parents 3c7e0c9 + 751c00f commit e85d3f6
Show file tree
Hide file tree
Showing 25 changed files with 630 additions and 244 deletions.
Binary file modified applications/aws_dashboard/assets/favicon.ico
Binary file not shown.
34 changes: 10 additions & 24 deletions applications/aws_dashboard/pages/endpoints/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Callbacks for the Endpoints Subpage Web User Interface"""

import logging
from dash import Dash, no_update
from dash.dependencies import Input, Output, State

# SageWorks Imports
from sageworks.views.endpoint_web_view import EndpointWebView
from sageworks.web_components import table, endpoint_metric_plots
from sageworks.web_components import table, endpoint_metric_plots, plugin_callbacks
from sageworks.utils.pandas_utils import deserialize_aws_broker_data
from sageworks.api.endpoint import Endpoint

# Get the SageWorks logger
log = logging.getLogger("sageworks")


def update_endpoints_table(app: Dash):
Expand Down Expand Up @@ -80,25 +83,8 @@ def generate_endpoint_details_figures(selected_rows, table_data):
return endpoint_metrics_figure


# Updates the plugin component when a endpoint row is selected
def update_plugin(app: Dash, plugin):
@app.callback(
Output(plugin.component_id(), "figure"),
Input("endpoints_table", "derived_viewport_selected_row_ids"),
State("endpoints_table", "data"),
prevent_initial_call=True,
)
def update_callback(selected_rows, table_data):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update

# Get the selected row data and grab the uuid
selected_row_data = table_data[selected_rows[0]]
endpoint_uuid = selected_row_data["uuid"]

# Instantiate the Endpoint and send it to the plugin
endpoint = Endpoint(endpoint_uuid, legacy=True)

# Instantiate the Endpoint and send it to the plugin
return plugin.update_contents(endpoint)
# Updates the plugin components when a model row is selected
def update_plugins(plugins):
# Setup the inputs for the plugins and register the callbacks
endpoint_inputs = [Input("endpoints_table", "derived_viewport_selected_row_ids"), State("endpoints_table", "data")]
plugin_callbacks.register_callbacks(plugins, endpoint_inputs, "endpoint")
1 change: 1 addition & 0 deletions applications/aws_dashboard/pages/endpoints/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def endpoints_layout(
[
html.H2("SageWorks: Endpoints"),
dbc.Row(style={"padding": "30px 0px 0px 0px"}),
html.Div(id="dev_null", style={"display": "none"}), # Output for callbacks without outputs
]
),
# A table that lists out all the Models
Expand Down
7 changes: 3 additions & 4 deletions applications/aws_dashboard/pages/endpoints/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

# Add the plugins to the components dictionary
for plugin in plugins:
component_id = plugin.component_id()
component_id = plugin.generate_component_id()
components[component_id] = plugin.create_component(component_id)

# Set up our layout (Dash looks for a var called layout)
Expand All @@ -67,6 +67,5 @@
callbacks.table_row_select(app, "endpoints_table")
callbacks.update_endpoint_metrics(app, endpoint_broker)

# For each plugin, set up a callback to update the plugin figure
for plugin in plugins:
callbacks.update_plugin(app, plugin)
# For all the plugins we have we'll call their update_contents method
callbacks.update_plugins(plugins)
36 changes: 13 additions & 23 deletions applications/aws_dashboard/pages/models/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Callbacks for the Model Subpage Web User Interface"""

import logging
from dash import Dash, no_update
from dash.dependencies import Input, Output, State

# SageWorks Imports
from sageworks.web_components import (
table,
model_plot,
)
from sageworks.web_components import table, model_plot, plugin_callbacks
from sageworks.utils.pandas_utils import deserialize_aws_broker_data
from sageworks.api.model import Model

# Get the SageWorks logger
log = logging.getLogger("sageworks")


def update_models_table(app: Dash):
@app.callback(
Expand Down Expand Up @@ -76,23 +77,12 @@ def generate_model_plot_figure(inference_run, table_data, selected_rows):
return model_plot_fig


# Updates the plugin component when a model row is selected
def update_plugin(app: Dash, plugin):
@app.callback(
Output(plugin.component_id(), "figure"),
[Input("model_details-dropdown", "value"), Input("models_table", "derived_viewport_selected_row_ids")],
# Updates the plugin components when a model row is selected
def update_plugins(plugins):
# Setup the inputs for the plugins and register the callbacks
model_inputs = [
Input("model_details-dropdown", "value"),
Input("models_table", "derived_viewport_selected_row_ids"),
State("models_table", "data"),
prevent_initial_call=True,
)
def update_plugin_figure(inference_run, selected_rows, table_data):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update

# Get the selected row data and grab the uuid
selected_row_data = table_data[selected_rows[0]]
model_uuid = selected_row_data["uuid"]

# Instantiate the Model and send it to the plugin
model = Model(model_uuid, legacy=True)
return plugin.update_contents(model, inference_run=inference_run)
]
plugin_callbacks.register_callbacks(plugins, model_inputs, "model")
1 change: 1 addition & 0 deletions applications/aws_dashboard/pages/models/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def models_layout(
[
html.H2("SageWorks: Models"),
dbc.Row(style={"padding": "30px 0px 0px 0px"}),
html.Div(id="dev_null", style={"display": "none"}), # Output for callbacks without outputs
]
),
# A table that lists out all the Models
Expand Down
7 changes: 3 additions & 4 deletions applications/aws_dashboard/pages/models/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

# Add the plugins to the components dictionary
for plugin in plugins:
component_id = plugin.component_id()
component_id = plugin.generate_component_id()
components[component_id] = plugin.create_component(component_id)

# Set up our layout (Dash looks for a var called layout)
Expand All @@ -67,6 +67,5 @@
callbacks.table_row_select(app, "models_table")
callbacks.update_model_plot_component(app)

# For each plugin, set up a callback to update the plugin figure
for plugin in plugins:
callbacks.update_plugin(app, plugin)
# Set up callbacks for all the plugins
callbacks.update_plugins(plugins)
10 changes: 5 additions & 5 deletions examples/plugins/pages/plugin_page_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def page_setup(self, app: dash.Dash):

# Register this page with Dash and set up the layout (required)
register_page(
__name__,
"plugin",
path="/plugin_1",
name=self.page_name,
layout=self.page_layout(),
Expand All @@ -34,7 +34,7 @@ def page_layout(self) -> dash.html.Div:
import webbrowser

# Create our Dash Application
app = dash.Dash(
my_app = dash.Dash(
__name__,
title="SageWorks Dashboard",
use_pages=True,
Expand All @@ -43,14 +43,14 @@ def page_layout(self) -> dash.html.Div:
)

# For Multi-Page Applications, we need to create a 'page container' to hold all the pages
app.layout = html.Div([page_container])
my_app.layout = html.Div([page_container])

# Create the Plugin Page and call page_setup
plugin_page = PluginPage1()
plugin_page.page_setup(app)
plugin_page.page_setup(my_app)

# Open the browser to the plugin page
webbrowser.open("http://localhost:8000/plugin_1")

# Note: This 'main' is purely for running/testing locally
app.run(host="0.0.0.0", port=8000, debug=True)
my_app.run(host="localhost", port=8000, debug=True)
14 changes: 7 additions & 7 deletions examples/plugins/pages/plugin_page_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class PluginPage2:
"""Plugin Page: A SageWorks Plugin Web Interface"""
"""Plugin Page: A SageWorks Plugin Page Interface"""

def __init__(self):
"""Initialize the Plugin Page"""
Expand All @@ -27,9 +27,9 @@ def page_setup(self, app: dash.Dash):
"my_model_table", header_color="rgb(60, 60, 60)", row_select="single", max_height=400
)

# Register this page with Dash and set up the layout (required)
# Register this page with Dash and set up the layout
register_page(
__name__,
"plugin",
path="/plugin_2",
name=self.page_name,
layout=self.page_layout(),
Expand Down Expand Up @@ -58,7 +58,7 @@ def page_layout(self) -> dash.html.Div:
import webbrowser

# Create our Dash Application
app = dash.Dash(
my_app = dash.Dash(
__name__,
title="SageWorks Dashboard",
use_pages=True,
Expand All @@ -67,14 +67,14 @@ def page_layout(self) -> dash.html.Div:
)

# For Multi-Page Applications, we need to create a 'page container' to hold all the pages
app.layout = html.Div([page_container])
my_app.layout = html.Div([page_container])

# Create the Plugin Page and call page_setup
plugin_page = PluginPage2()
plugin_page.page_setup(app)
plugin_page.page_setup(my_app)

# Open the browser to the plugin page
webbrowser.open("http://localhost:8000/plugin_2")

# Note: This 'main' is purely for running/testing locally
app.run(host="0.0.0.0", port=8000, debug=True)
my_app.run(host="localhost", port=8000, debug=True)
24 changes: 12 additions & 12 deletions examples/plugins/pages/plugin_page_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def page_setup(self, app: dash.Dash):
self.details_component = self.model_details.create_component("my_model_details")
self.plot_component = self.model_plot.create_component("my_model_plot")

# Register this page with Dash and set up the layout (required)
# Register this page with Dash and set up the layout
register_page(
__name__,
"plugin",
path="/plugin_3",
name=self.page_name,
layout=self.page_layout(),
Expand All @@ -54,7 +54,7 @@ def page_setup(self, app: dash.Dash):
self.table_component.data = models.to_dict("records")

# Register the callbacks
self.register_callbacks()
self.register_callbacks(app)
self.model_details.register_callbacks("my_model_table")

def page_layout(self) -> dash.html.Div:
Expand All @@ -73,16 +73,16 @@ def page_layout(self) -> dash.html.Div:
)
return layout

def register_callbacks(self):
def register_callbacks(self, app: dash.Dash):
"""Register the callbacks for the page"""

@callback(
@app.callback(
Output("my_model_plot", "figure"),
Input("my_model_details-dropdown", "value"),
[State("my_model_table", "data"), State("my_model_table", "derived_viewport_selected_row_ids")],
[Input("my_model_details-dropdown", "value"), Input("my_model_table", "derived_viewport_selected_row_ids")],
State("my_model_table", "data"),
prevent_initial_call=True,
)
def generate_model_plot_figure(inference_run, table_data, selected_rows):
def generate_model_plot_figure(inference_run, selected_rows, table_data):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update
Expand All @@ -104,7 +104,7 @@ def generate_model_plot_figure(inference_run, table_data, selected_rows):
import webbrowser

# Create our Dash Application
app = dash.Dash(
my_app = dash.Dash(
__name__,
title="SageWorks Dashboard",
use_pages=True,
Expand All @@ -114,14 +114,14 @@ def generate_model_plot_figure(inference_run, table_data, selected_rows):
)

# For Multi-Page Applications, we need to create a 'page container' to hold all the pages
app.layout = html.Div([page_container])
my_app.layout = html.Div([page_container])

# Create the Plugin Page and call page_setup
plugin_page = PluginPage3()
plugin_page.page_setup(app)
plugin_page.page_setup(my_app)

# Open the browser to the plugin page
webbrowser.open("http://localhost:8000/plugin_3")

# Note: This 'main' is purely for running/testing locally
app.run(host="0.0.0.0", port=8000, debug=True)
my_app.run(host="localhost", port=8000, debug=True)
42 changes: 20 additions & 22 deletions examples/plugins/web_components/custom_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""A Custom plugin component"""

from dash import dcc
import plotly.graph_objects as go


# SageWorks Imports
Expand All @@ -13,7 +12,7 @@ class CustomPlugin(PluginInterface):
"""CustomPlugin Component"""

"""Initialize this Plugin Component Class with required attributes"""
plugin_page = PluginPage.CUSTOM
auto_load_page = PluginPage.CUSTOM
plugin_input_type = PluginInputType.MODEL

def create_component(self, component_id: str) -> dcc.Graph:
Expand All @@ -23,34 +22,33 @@ def create_component(self, component_id: str) -> dcc.Graph:
Returns:
dcc.Graph: The EndpointTurbo Component
"""
return dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
self.component_id = component_id
self.container = dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))

# Fill in content slots
self.slots = [(self.component_id, "figure")]

# Return the container
return self.container

def update_contents(self, model: Model, **kwargs) -> list:
"""Update the CustomPlugin contents
def update_contents(self, model: Model, **kwargs) -> go.Figure:
"""Create a CustomPlugin Figure
Args:
model (Model): An instantiated Endpoint object
**kwargs: Additional keyword arguments (unused)
Returns:
go.Figure: A Plotly Figure object
list: A list of the updated contents (children)
"""
model_name = f"Model: {model.uuid}"
return self.display_text(model_name)
text_figure = self.display_text(model_name, figure_height=100)
return [text_figure]


if __name__ == "__main__":
# This class takes in a model object

# Instantiate an Endpoint
my_model = Model("abalone-regression")

# Instantiate the EndpointTurbo class
plugin = CustomPlugin()

# Generate the figure
fig = plugin.update_contents(my_model)

# Apply dark theme
fig.update_layout(template="plotly_dark")
# A Unit Test for the Plugin
from sageworks.web_components.plugin_unit_test import PluginUnitTest

# Show the figure
fig.show()
# Run the Unit Test on the Plugin
PluginUnitTest(CustomPlugin).run()
Loading

0 comments on commit e85d3f6

Please sign in to comment.