diff --git a/applications/aws_dashboard/assets/favicon.ico b/applications/aws_dashboard/assets/favicon.ico
index 7e47cd163..185c59f58 100644
Binary files a/applications/aws_dashboard/assets/favicon.ico and b/applications/aws_dashboard/assets/favicon.ico differ
diff --git a/applications/aws_dashboard/pages/endpoints/callbacks.py b/applications/aws_dashboard/pages/endpoints/callbacks.py
index 01bd5914e..cd22319e3 100644
--- a/applications/aws_dashboard/pages/endpoints/callbacks.py
+++ b/applications/aws_dashboard/pages/endpoints/callbacks.py
@@ -1,13 +1,16 @@
"""Callbacks for the Endpoints Subpage Web User Interface"""
+import logging
from dash import Dash, no_update
from dash.dependencies import Input, Output, State
# SageWorks Imports
from sageworks.views.endpoint_web_view import EndpointWebView
-from sageworks.web_components import table, endpoint_metric_plots
+from sageworks.web_components import table, endpoint_metric_plots, plugin_callbacks
from sageworks.utils.pandas_utils import deserialize_aws_broker_data
-from sageworks.api.endpoint import Endpoint
+
+# Get the SageWorks logger
+log = logging.getLogger("sageworks")
def update_endpoints_table(app: Dash):
@@ -80,25 +83,8 @@ def generate_endpoint_details_figures(selected_rows, table_data):
return endpoint_metrics_figure
-# Updates the plugin component when a endpoint row is selected
-def update_plugin(app: Dash, plugin):
- @app.callback(
- Output(plugin.component_id(), "figure"),
- Input("endpoints_table", "derived_viewport_selected_row_ids"),
- State("endpoints_table", "data"),
- prevent_initial_call=True,
- )
- def update_callback(selected_rows, table_data):
- # Check for no selected rows
- if not selected_rows or selected_rows[0] is None:
- return no_update
-
- # Get the selected row data and grab the uuid
- selected_row_data = table_data[selected_rows[0]]
- endpoint_uuid = selected_row_data["uuid"]
-
- # Instantiate the Endpoint and send it to the plugin
- endpoint = Endpoint(endpoint_uuid, legacy=True)
-
- # Instantiate the Endpoint and send it to the plugin
- return plugin.update_contents(endpoint)
+# Updates the plugin components when a model row is selected
+def update_plugins(plugins):
+ # Setup the inputs for the plugins and register the callbacks
+ endpoint_inputs = [Input("endpoints_table", "derived_viewport_selected_row_ids"), State("endpoints_table", "data")]
+ plugin_callbacks.register_callbacks(plugins, endpoint_inputs, "endpoint")
diff --git a/applications/aws_dashboard/pages/endpoints/layout.py b/applications/aws_dashboard/pages/endpoints/layout.py
index c4780a0cb..920011e2a 100644
--- a/applications/aws_dashboard/pages/endpoints/layout.py
+++ b/applications/aws_dashboard/pages/endpoints/layout.py
@@ -25,6 +25,7 @@ def endpoints_layout(
[
html.H2("SageWorks: Endpoints"),
dbc.Row(style={"padding": "30px 0px 0px 0px"}),
+ html.Div(id="dev_null", style={"display": "none"}), # Output for callbacks without outputs
]
),
# A table that lists out all the Models
diff --git a/applications/aws_dashboard/pages/endpoints/page.py b/applications/aws_dashboard/pages/endpoints/page.py
index 7a8cd7f33..d13ede6a4 100644
--- a/applications/aws_dashboard/pages/endpoints/page.py
+++ b/applications/aws_dashboard/pages/endpoints/page.py
@@ -52,7 +52,7 @@
# Add the plugins to the components dictionary
for plugin in plugins:
- component_id = plugin.component_id()
+ component_id = plugin.generate_component_id()
components[component_id] = plugin.create_component(component_id)
# Set up our layout (Dash looks for a var called layout)
@@ -67,6 +67,5 @@
callbacks.table_row_select(app, "endpoints_table")
callbacks.update_endpoint_metrics(app, endpoint_broker)
-# For each plugin, set up a callback to update the plugin figure
-for plugin in plugins:
- callbacks.update_plugin(app, plugin)
+# For all the plugins we have we'll call their update_contents method
+callbacks.update_plugins(plugins)
diff --git a/applications/aws_dashboard/pages/models/callbacks.py b/applications/aws_dashboard/pages/models/callbacks.py
index 58c829cdd..a1237bb17 100644
--- a/applications/aws_dashboard/pages/models/callbacks.py
+++ b/applications/aws_dashboard/pages/models/callbacks.py
@@ -1,16 +1,17 @@
"""Callbacks for the Model Subpage Web User Interface"""
+import logging
from dash import Dash, no_update
from dash.dependencies import Input, Output, State
# SageWorks Imports
-from sageworks.web_components import (
- table,
- model_plot,
-)
+from sageworks.web_components import table, model_plot, plugin_callbacks
from sageworks.utils.pandas_utils import deserialize_aws_broker_data
from sageworks.api.model import Model
+# Get the SageWorks logger
+log = logging.getLogger("sageworks")
+
def update_models_table(app: Dash):
@app.callback(
@@ -76,23 +77,12 @@ def generate_model_plot_figure(inference_run, table_data, selected_rows):
return model_plot_fig
-# Updates the plugin component when a model row is selected
-def update_plugin(app: Dash, plugin):
- @app.callback(
- Output(plugin.component_id(), "figure"),
- [Input("model_details-dropdown", "value"), Input("models_table", "derived_viewport_selected_row_ids")],
+# Updates the plugin components when a model row is selected
+def update_plugins(plugins):
+ # Setup the inputs for the plugins and register the callbacks
+ model_inputs = [
+ Input("model_details-dropdown", "value"),
+ Input("models_table", "derived_viewport_selected_row_ids"),
State("models_table", "data"),
- prevent_initial_call=True,
- )
- def update_plugin_figure(inference_run, selected_rows, table_data):
- # Check for no selected rows
- if not selected_rows or selected_rows[0] is None:
- return no_update
-
- # Get the selected row data and grab the uuid
- selected_row_data = table_data[selected_rows[0]]
- model_uuid = selected_row_data["uuid"]
-
- # Instantiate the Model and send it to the plugin
- model = Model(model_uuid, legacy=True)
- return plugin.update_contents(model, inference_run=inference_run)
+ ]
+ plugin_callbacks.register_callbacks(plugins, model_inputs, "model")
diff --git a/applications/aws_dashboard/pages/models/layout.py b/applications/aws_dashboard/pages/models/layout.py
index 68a4ddd0d..8184a8f39 100644
--- a/applications/aws_dashboard/pages/models/layout.py
+++ b/applications/aws_dashboard/pages/models/layout.py
@@ -25,6 +25,7 @@ def models_layout(
[
html.H2("SageWorks: Models"),
dbc.Row(style={"padding": "30px 0px 0px 0px"}),
+ html.Div(id="dev_null", style={"display": "none"}), # Output for callbacks without outputs
]
),
# A table that lists out all the Models
diff --git a/applications/aws_dashboard/pages/models/page.py b/applications/aws_dashboard/pages/models/page.py
index b1dd626e9..cf3d26e26 100644
--- a/applications/aws_dashboard/pages/models/page.py
+++ b/applications/aws_dashboard/pages/models/page.py
@@ -52,7 +52,7 @@
# Add the plugins to the components dictionary
for plugin in plugins:
- component_id = plugin.component_id()
+ component_id = plugin.generate_component_id()
components[component_id] = plugin.create_component(component_id)
# Set up our layout (Dash looks for a var called layout)
@@ -67,6 +67,5 @@
callbacks.table_row_select(app, "models_table")
callbacks.update_model_plot_component(app)
-# For each plugin, set up a callback to update the plugin figure
-for plugin in plugins:
- callbacks.update_plugin(app, plugin)
+# Set up callbacks for all the plugins
+callbacks.update_plugins(plugins)
diff --git a/examples/plugins/pages/plugin_page_1.py b/examples/plugins/pages/plugin_page_1.py
index 1cfcbad2e..4288d2217 100644
--- a/examples/plugins/pages/plugin_page_1.py
+++ b/examples/plugins/pages/plugin_page_1.py
@@ -17,7 +17,7 @@ def page_setup(self, app: dash.Dash):
# Register this page with Dash and set up the layout (required)
register_page(
- __name__,
+ "plugin",
path="/plugin_1",
name=self.page_name,
layout=self.page_layout(),
@@ -34,7 +34,7 @@ def page_layout(self) -> dash.html.Div:
import webbrowser
# Create our Dash Application
- app = dash.Dash(
+ my_app = dash.Dash(
__name__,
title="SageWorks Dashboard",
use_pages=True,
@@ -43,14 +43,14 @@ def page_layout(self) -> dash.html.Div:
)
# For Multi-Page Applications, we need to create a 'page container' to hold all the pages
- app.layout = html.Div([page_container])
+ my_app.layout = html.Div([page_container])
# Create the Plugin Page and call page_setup
plugin_page = PluginPage1()
- plugin_page.page_setup(app)
+ plugin_page.page_setup(my_app)
# Open the browser to the plugin page
webbrowser.open("http://localhost:8000/plugin_1")
# Note: This 'main' is purely for running/testing locally
- app.run(host="0.0.0.0", port=8000, debug=True)
+ my_app.run(host="localhost", port=8000, debug=True)
diff --git a/examples/plugins/pages/plugin_page_2.py b/examples/plugins/pages/plugin_page_2.py
index 96fb6321b..537cbf361 100644
--- a/examples/plugins/pages/plugin_page_2.py
+++ b/examples/plugins/pages/plugin_page_2.py
@@ -10,7 +10,7 @@
class PluginPage2:
- """Plugin Page: A SageWorks Plugin Web Interface"""
+ """Plugin Page: A SageWorks Plugin Page Interface"""
def __init__(self):
"""Initialize the Plugin Page"""
@@ -27,9 +27,9 @@ def page_setup(self, app: dash.Dash):
"my_model_table", header_color="rgb(60, 60, 60)", row_select="single", max_height=400
)
- # Register this page with Dash and set up the layout (required)
+ # Register this page with Dash and set up the layout
register_page(
- __name__,
+ "plugin",
path="/plugin_2",
name=self.page_name,
layout=self.page_layout(),
@@ -58,7 +58,7 @@ def page_layout(self) -> dash.html.Div:
import webbrowser
# Create our Dash Application
- app = dash.Dash(
+ my_app = dash.Dash(
__name__,
title="SageWorks Dashboard",
use_pages=True,
@@ -67,14 +67,14 @@ def page_layout(self) -> dash.html.Div:
)
# For Multi-Page Applications, we need to create a 'page container' to hold all the pages
- app.layout = html.Div([page_container])
+ my_app.layout = html.Div([page_container])
# Create the Plugin Page and call page_setup
plugin_page = PluginPage2()
- plugin_page.page_setup(app)
+ plugin_page.page_setup(my_app)
# Open the browser to the plugin page
webbrowser.open("http://localhost:8000/plugin_2")
# Note: This 'main' is purely for running/testing locally
- app.run(host="0.0.0.0", port=8000, debug=True)
+ my_app.run(host="localhost", port=8000, debug=True)
diff --git a/examples/plugins/pages/plugin_page_3.py b/examples/plugins/pages/plugin_page_3.py
index 348895b2f..35e5ee2c1 100644
--- a/examples/plugins/pages/plugin_page_3.py
+++ b/examples/plugins/pages/plugin_page_3.py
@@ -38,9 +38,9 @@ def page_setup(self, app: dash.Dash):
self.details_component = self.model_details.create_component("my_model_details")
self.plot_component = self.model_plot.create_component("my_model_plot")
- # Register this page with Dash and set up the layout (required)
+ # Register this page with Dash and set up the layout
register_page(
- __name__,
+ "plugin",
path="/plugin_3",
name=self.page_name,
layout=self.page_layout(),
@@ -54,7 +54,7 @@ def page_setup(self, app: dash.Dash):
self.table_component.data = models.to_dict("records")
# Register the callbacks
- self.register_callbacks()
+ self.register_callbacks(app)
self.model_details.register_callbacks("my_model_table")
def page_layout(self) -> dash.html.Div:
@@ -73,16 +73,16 @@ def page_layout(self) -> dash.html.Div:
)
return layout
- def register_callbacks(self):
+ def register_callbacks(self, app: dash.Dash):
"""Register the callbacks for the page"""
- @callback(
+ @app.callback(
Output("my_model_plot", "figure"),
- Input("my_model_details-dropdown", "value"),
- [State("my_model_table", "data"), State("my_model_table", "derived_viewport_selected_row_ids")],
+ [Input("my_model_details-dropdown", "value"), Input("my_model_table", "derived_viewport_selected_row_ids")],
+ State("my_model_table", "data"),
prevent_initial_call=True,
)
- def generate_model_plot_figure(inference_run, table_data, selected_rows):
+ def generate_model_plot_figure(inference_run, selected_rows, table_data):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update
@@ -104,7 +104,7 @@ def generate_model_plot_figure(inference_run, table_data, selected_rows):
import webbrowser
# Create our Dash Application
- app = dash.Dash(
+ my_app = dash.Dash(
__name__,
title="SageWorks Dashboard",
use_pages=True,
@@ -114,14 +114,14 @@ def generate_model_plot_figure(inference_run, table_data, selected_rows):
)
# For Multi-Page Applications, we need to create a 'page container' to hold all the pages
- app.layout = html.Div([page_container])
+ my_app.layout = html.Div([page_container])
# Create the Plugin Page and call page_setup
plugin_page = PluginPage3()
- plugin_page.page_setup(app)
+ plugin_page.page_setup(my_app)
# Open the browser to the plugin page
webbrowser.open("http://localhost:8000/plugin_3")
# Note: This 'main' is purely for running/testing locally
- app.run(host="0.0.0.0", port=8000, debug=True)
+ my_app.run(host="localhost", port=8000, debug=True)
diff --git a/examples/plugins/web_components/custom_plugin.py b/examples/plugins/web_components/custom_plugin.py
index fc9919e99..90c434bc6 100644
--- a/examples/plugins/web_components/custom_plugin.py
+++ b/examples/plugins/web_components/custom_plugin.py
@@ -1,7 +1,6 @@
"""A Custom plugin component"""
from dash import dcc
-import plotly.graph_objects as go
# SageWorks Imports
@@ -13,7 +12,7 @@ class CustomPlugin(PluginInterface):
"""CustomPlugin Component"""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.CUSTOM
+ auto_load_page = PluginPage.CUSTOM
plugin_input_type = PluginInputType.MODEL
def create_component(self, component_id: str) -> dcc.Graph:
@@ -23,34 +22,33 @@ def create_component(self, component_id: str) -> dcc.Graph:
Returns:
dcc.Graph: The EndpointTurbo Component
"""
- return dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+ self.component_id = component_id
+ self.container = dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+
+ # Fill in content slots
+ self.slots = [(self.component_id, "figure")]
+
+ # Return the container
+ return self.container
+
+ def update_contents(self, model: Model, **kwargs) -> list:
+ """Update the CustomPlugin contents
- def update_contents(self, model: Model, **kwargs) -> go.Figure:
- """Create a CustomPlugin Figure
Args:
model (Model): An instantiated Endpoint object
**kwargs: Additional keyword arguments (unused)
+
Returns:
- go.Figure: A Plotly Figure object
+ list: A list of the updated contents (children)
"""
model_name = f"Model: {model.uuid}"
- return self.display_text(model_name)
+ text_figure = self.display_text(model_name, figure_height=100)
+ return [text_figure]
if __name__ == "__main__":
- # This class takes in a model object
-
- # Instantiate an Endpoint
- my_model = Model("abalone-regression")
-
- # Instantiate the EndpointTurbo class
- plugin = CustomPlugin()
-
- # Generate the figure
- fig = plugin.update_contents(my_model)
-
- # Apply dark theme
- fig.update_layout(template="plotly_dark")
+ # A Unit Test for the Plugin
+ from sageworks.web_components.plugin_unit_test import PluginUnitTest
- # Show the figure
- fig.show()
+ # Run the Unit Test on the Plugin
+ PluginUnitTest(CustomPlugin).run()
diff --git a/examples/plugins/web_components/endpoint_plugin.py b/examples/plugins/web_components/endpoint_plugin.py
index 0b064bd23..26760d39c 100644
--- a/examples/plugins/web_components/endpoint_plugin.py
+++ b/examples/plugins/web_components/endpoint_plugin.py
@@ -1,7 +1,7 @@
"""An example Endpoint plugin component"""
+import logging
from dash import dcc
-import plotly.graph_objects as go
# SageWorks Imports
@@ -9,11 +9,15 @@
from sageworks.api.endpoint import Endpoint
+# Get the SageWorks logger
+log = logging.getLogger("sageworks")
+
+
class MyEndpointPlugin(PluginInterface):
"""MyEndpointPlugin Component"""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.ENDPOINT
+ auto_load_page = PluginPage.ENDPOINT
plugin_input_type = PluginInputType.ENDPOINT
def create_component(self, component_id: str) -> dcc.Graph:
@@ -23,34 +27,36 @@ def create_component(self, component_id: str) -> dcc.Graph:
Returns:
dcc.Graph: The EndpointTurbo Component
"""
- return dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+ self.component_id = component_id
+ self.container = dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+
+ # Fill in content slots
+ self.slots = [(self.component_id, "figure")]
+
+ # Return the container
+ return self.container
+
+ def update_contents(self, endpoint: Endpoint, **kwargs) -> list:
+ """Create a Endpoint Plugin Figure
- def update_contents(self, endpoint: Endpoint, **kwargs) -> go.Figure:
- """Create a CustomPlugin Figure
Args:
endpoint (Endpoint): An instantiated Endpoint object
**kwargs: Additional keyword arguments (unused)
+
Returns:
- go.Figure: A Plotly Figure object
+ list: A list of the updated contents (children)
"""
+ log.important(f"Updating Model Plugin with Model: {endpoint.uuid} and kwargs: {kwargs}")
endpoint_name = f"Endpoint: {endpoint.uuid}"
- return self.display_text(endpoint_name, figure_height=200)
+ text_figure = self.display_text(endpoint_name, figure_height=100)
+ # Return the updated contents
+ return [text_figure]
-if __name__ == "__main__":
- # This class takes in a model object
-
- # Instantiate an Endpoint
- my_endpoint = Endpoint("abalone-regression-end")
-
- # Instantiate the EndpointTurbo class
- plugin = MyEndpointPlugin()
- # Generate the figure
- fig = plugin.update_contents(my_endpoint)
-
- # Apply dark theme
- fig.update_layout(template="plotly_dark")
+if __name__ == "__main__":
+ # A Unit Test for the Plugin
+ from sageworks.web_components.plugin_unit_test import PluginUnitTest
- # Show the figure
- fig.show()
+ # Run the Unit Test on the Plugin
+ PluginUnitTest(MyEndpointPlugin, test_type="endpoint").run()
diff --git a/examples/plugins/web_components/endpoint_turbo.py b/examples/plugins/web_components/endpoint_turbo.py
index 0906cbcb2..68d516efc 100644
--- a/examples/plugins/web_components/endpoint_turbo.py
+++ b/examples/plugins/web_components/endpoint_turbo.py
@@ -13,7 +13,7 @@ class EndpointTurbo(PluginInterface):
"""EndpointTurbo Component"""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.ENDPOINT
+ auto_load_page = PluginPage.ENDPOINT
plugin_input_type = PluginInputType.ENDPOINT
def create_component(self, component_id: str) -> dcc.Graph:
@@ -23,20 +23,26 @@ def create_component(self, component_id: str) -> dcc.Graph:
Returns:
dcc.Graph: The EndpointTurbo Component
"""
- return dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+ self.component_id = component_id
+ self.container = dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+
+ # Fill in content slots
+ self.slots = [(self.component_id, "figure")]
+
+ # Return the container
+ return self.container
+
+ def update_contents(self, endpoint: Endpoint, **kwargs) -> list:
+ """Update the contents for the plugin.
- def update_contents(self, endpoint: Endpoint, **kwargs) -> go.Figure:
- """Create a EndpointTurbo Figure for the numeric columns in the dataframe.
Args:
- endpoint (Endpoint): An instantiated Endpoint object
+ endpoint (Endpoint): An instantiated Model object
**kwargs: Additional keyword arguments (unused)
+
Returns:
- go.Figure: A Plotly Figure object
+ list: A list of the updated contents (children)
"""
- # Just to make sure we have the right endpoint object
- print(endpoint.summary())
-
data = [ # Portfolio (inner donut)
# Inner ring
go.Pie(
@@ -83,27 +89,18 @@ def update_contents(self, endpoint: Endpoint, **kwargs) -> go.Figure:
]
# Create the nested pie chart plot with custom settings
- fig = go.Figure(data=data)
- fig.update_layout(margin={"t": 10, "b": 10, "r": 10, "l": 10, "pad": 10}, height=400)
+ endpoint_name = f"Endpoint: {endpoint.uuid}"
+ turbo_figure = go.Figure(data=data)
+ turbo_figure.update_layout(
+ margin={"t": 30, "b": 10, "r": 10, "l": 10, "pad": 10}, title=endpoint_name, height=400
+ )
- return fig
+ # Return the updated contents
+ return [turbo_figure]
if __name__ == "__main__":
- # This class takes in model details and generates a EndpointTurbo
- from sageworks.api.endpoint import Endpoint
-
- # Instantiate an Endpoint
- end = Endpoint("abalone-regression-end")
-
- # Instantiate the EndpointTurbo class
- pie = EndpointTurbo()
-
- # Generate the figure
- fig = pie.update_contents(end)
-
- # Apply dark theme
- fig.update_layout(template="plotly_dark")
+ from sageworks.web_components.plugin_unit_test import PluginUnitTest
- # Show the figure
- fig.show()
+ # Run the Unit Test on the Plugin
+ PluginUnitTest(EndpointTurbo).run()
diff --git a/examples/plugins/web_components/model_markdown.py b/examples/plugins/web_components/model_markdown.py
new file mode 100644
index 000000000..a3427c2c2
--- /dev/null
+++ b/examples/plugins/web_components/model_markdown.py
@@ -0,0 +1,84 @@
+"""A Markdown Plugin Example for details/information about Models"""
+
+import logging
+
+# Dash Imports
+from dash import html, dcc
+
+# SageWorks Imports
+from sageworks.api import Model
+from sageworks.web_components.plugin_interface import PluginInterface, PluginPage, PluginInputType
+
+# Get the SageWorks logger
+log = logging.getLogger("sageworks")
+
+
+class MyModelMarkdown(PluginInterface):
+ """MyModelMarkdown Component"""
+
+ """Initialize this Plugin Component Class with required attributes"""
+ auto_load_page = PluginPage.MODEL
+ plugin_input_type = PluginInputType.MODEL
+
+ def create_component(self, component_id: str) -> html.Div:
+ """Create a Model Markdown Component without any data
+
+ Args:
+ component_id (str): The ID of the web component
+ Returns:
+ dcc.Graph: The EndpointTurbo Component
+ """
+ self.component_id = component_id
+ self.container = html.Div(
+ id=self.component_id,
+ children=[
+ html.H3(id=f"{self.component_id}-header", children="Model: Loading..."),
+ dcc.Markdown(id=f"{self.component_id}-details"),
+ ],
+ )
+
+ # Fill in content slots
+ self.slots = [
+ (f"{self.component_id}-header", "children"),
+ (f"{self.component_id}-details", "children"),
+ ]
+
+ # Return the container
+ return self.container
+
+ def update_contents(self, model: Model, **kwargs) -> list:
+ """Update the contents for this plugin component
+
+ Args:
+ model (Model): An instantiated Model object
+ **kwargs: Additional keyword arguments (unused)
+
+ Returns:
+ list: A list of the updated contents (children)
+ """
+ log.important(f"Updating Model Markdown Plugin with Model: {model.uuid} and kwargs: {kwargs}")
+
+ # Update the html header
+ header = f"Model: {model.uuid}"
+
+ # Make Markdown for the model summary
+ summary = model.summary()
+ markdown = ""
+ for key, value in summary.items():
+
+ # Chop off the "sageworks_" prefix
+ key = key.replace("sageworks_", "")
+
+ # Add to markdown string
+ markdown += f"**{key}:** {value} \n"
+
+ # Return the updated contents (must match slots)
+ return header, markdown
+
+
+# Unit Test for the Plugin
+if __name__ == "__main__":
+ from sageworks.web_components.plugin_unit_test import PluginUnitTest
+
+ # Run the Unit Test on the Plugin
+ PluginUnitTest(MyModelMarkdown).run()
diff --git a/examples/plugins/web_components/model_plugin.py b/examples/plugins/web_components/model_plugin.py
index a57ee6f12..c5a4e688c 100644
--- a/examples/plugins/web_components/model_plugin.py
+++ b/examples/plugins/web_components/model_plugin.py
@@ -1,64 +1,68 @@
"""An Example Model plugin component"""
+import logging
from dash import dcc
-import plotly.graph_objects as go
import random
-
+import plotly.graph_objects as go
# SageWorks Imports
from sageworks.web_components.plugin_interface import PluginInterface, PluginPage, PluginInputType
from sageworks.api.model import Model
+# Get the SageWorks logger
+log = logging.getLogger("sageworks")
+
-class MyModelPlugin(PluginInterface):
+class ModelPlugin(PluginInterface):
"""MyModelPlugin Component"""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.MODEL
+ auto_load_page = PluginPage.MODEL
plugin_input_type = PluginInputType.MODEL
def create_component(self, component_id: str) -> dcc.Graph:
- """Create a EndpointTurbo Component without any data.
+ """Create a Model Component without any data.
Args:
component_id (str): The ID of the web component
Returns:
dcc.Graph: The EndpointTurbo Component
"""
- return dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+ self.component_id = component_id
+ self.container = dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
+
+ # Fill in content slots
+ self.slots = [(self.component_id, "figure")]
+
+ # Return the container
+ return self.container
+
+ def update_contents(self, model: Model, **kwargs) -> list:
+ """Update the contents for the plugin.
- def update_contents(self, model: Model, **kwargs) -> go.Figure:
- """Create a Figure for the plugin.
Args:
model (Model): An instantiated Model object
**kwargs: Additional keyword arguments (unused)
+
Returns:
- go.Figure: A Plotly Figure object
+ list: A list of the updated contents (children)
"""
+ log.important(f"Updating Model Plugin with Model: {model.uuid} and kwargs: {kwargs}")
model_name = f"Model: {model.uuid}"
# Generate random values for the pie chart
- pie_values = [random.randint(10, 30) for _ in range(3)]
+ pie_values = [random.randint(10, 30) for _ in range(4)]
# Create a pie chart with the endpoint name as the title
- fig = go.Figure(data=[go.Pie(labels=["A", "B", "C"], values=pie_values)], layout=go.Layout(title=model_name))
- return fig
-
-
-if __name__ == "__main__":
- # This class takes in model details and generates a pie chart
- from sageworks.api.model import Model
-
- # Instantiate an Endpoint
- model = Model("abalone-regression")
+ pie_figure = go.Figure(
+ data=[go.Pie(labels=["A", "B", "C", "D"], values=pie_values)], layout=go.Layout(title=model_name)
+ )
- # Instantiate the EndpointTurbo class
- my_plugin = MyModelPlugin()
+ # Return the updated contents
+ return [pie_figure]
- # Generate the figure
- fig = my_plugin.update_contents(model)
- # Apply dark theme
- fig.update_layout(template="plotly_dark")
+if __name__ == "__main__":
+ from sageworks.web_components.plugin_unit_test import PluginUnitTest
- # Show the figure
- fig.show()
+ # Run the Unit Test on the Plugin
+ PluginUnitTest(ModelPlugin).run()
diff --git a/src/sageworks/utils/plugin_manager.py b/src/sageworks/utils/plugin_manager.py
index 190b9c2b9..6bf2a0541 100644
--- a/src/sageworks/utils/plugin_manager.py
+++ b/src/sageworks/utils/plugin_manager.py
@@ -167,7 +167,7 @@ def get_list_of_web_plugins(self, plugin_page: PluginPage = None) -> List[Any]:
plugin_classes = [
self.plugins["web_components"][x]
for x in self.plugins["web_components"]
- if self.plugins["web_components"][x].plugin_page == plugin_page
+ if self.plugins["web_components"][x].auto_load_page == plugin_page
]
return [x() for x in plugin_classes]
diff --git a/src/sageworks/web_components/ag_table.py b/src/sageworks/web_components/ag_table.py
new file mode 100644
index 000000000..a8284cc70
--- /dev/null
+++ b/src/sageworks/web_components/ag_table.py
@@ -0,0 +1,87 @@
+"""An Example Table plugin component using AG Grid"""
+
+import logging
+import pandas as pd
+from dash_ag_grid import AgGrid
+
+# SageWorks Imports
+from sageworks.web_components.plugin_interface import PluginInterface, PluginPage, PluginInputType
+
+# Get the SageWorks logger
+log = logging.getLogger("sageworks")
+
+
+class AGTable(PluginInterface):
+ """AGTable Component"""
+
+ """Initialize this Plugin Component Class with required attributes"""
+ auto_load_page = PluginPage.NONE
+ plugin_input_type = PluginInputType.MODEL_TABLE
+
+ def create_component(self, component_id: str) -> AgGrid:
+ """Create a Table Component without any data.
+ Args:
+ component_id (str): The ID of the web component
+ Returns:
+ AgGrid: The Table Component using AG Grid
+ """
+ self.component_id = component_id
+ self.container = AgGrid(
+ id=component_id,
+ # className="ag-theme-balham-dark",
+ columnSize="sizeToFit",
+ dashGridOptions={
+ "rowHeight": None,
+ "domLayout": "normal",
+ "rowSelection": "single",
+ "filter": True,
+ },
+ style={"maxHeight": "200px", "overflow": "auto"},
+ )
+
+ # Fill in content slots
+ self.slots = [
+ (self.component_id, "columnDefs"),
+ (self.component_id, "rowData"),
+ (self.component_id, "selectedRows"),
+ ]
+
+ # Output signals
+ self.signals = [
+ (self.component_id, "selectedRows"),
+ ]
+
+ # Return the container
+ return self.container
+
+ def update_contents(self, model_table: pd.DataFrame, **kwargs) -> list:
+ """Update the contents for the plugin.
+
+ Args:
+ model_table (pd.DataFrame): A DataFrame with the model table data
+ **kwargs: Additional keyword arguments (unused)
+
+ Returns:
+ list: A list of the updated contents (children)
+ """
+ log.important(f"Updating Table Plugin with a model table and kwargs: {kwargs}")
+
+ # Convert the DataFrame to a list of dictionaries for AG Grid
+ table_data = model_table.to_dict("records")
+
+ # Define column definitions based on the DataFrame
+ column_defs = [{"headerName": col, "field": col, "filter": "agTextColumnFilter"} for col in model_table.columns]
+
+ # Select the first row by default
+ selected_rows = model_table.head(1).to_dict("records")
+
+ # Return the column definitions and table data (must match the content slots)
+ return [column_defs, table_data, selected_rows]
+
+
+if __name__ == "__main__":
+ # Run the Unit Test for the Plugin
+ from sageworks.web_components.plugin_unit_test import PluginUnitTest
+
+ # Run the Unit Test on the Plugin
+ PluginUnitTest(AGTable).run()
diff --git a/src/sageworks/web_components/component_interface.py b/src/sageworks/web_components/component_interface.py
index 177e7b586..688abb4d9 100644
--- a/src/sageworks/web_components/component_interface.py
+++ b/src/sageworks/web_components/component_interface.py
@@ -1,16 +1,21 @@
"""An abstract class that defines the web component interface for SageWorks"""
+import logging
from abc import ABC, abstractmethod
from typing import Any, Union
+import textwrap
import re
from functools import wraps
-import logging
import plotly.graph_objects as go
-from dash import dcc, html, dash_table
+import pandas as pd
+from dash import dcc
+from dash.development.base_component import Component
# SageWorks Imports
from sageworks.api import DataSource, FeatureSet, Model, Endpoint
+log = logging.getLogger("sageworks")
+
class ComponentInterface(ABC):
"""A Abstract Web Component Interface
@@ -21,9 +26,13 @@ class ComponentInterface(ABC):
log = logging.getLogger("sageworks")
- SageworksObject = Union[DataSource, FeatureSet, Model, Endpoint]
- ComponentTypes = Union[dcc.Graph, dash_table.DataTable, dcc.Markdown, html.Div]
- ContentTypes = Union[go.Figure, str, None] # str = Markdown, None = No Update
+ SageworksObject = Union[DataSource, FeatureSet, Model, Endpoint, pd.DataFrame]
+
+ def __init__(self):
+ self.component_id = None
+ self.container = None
+ self.slots = []
+ self.signals = []
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
@@ -35,26 +44,30 @@ def __init_subclass__(cls, **kwargs):
cls.update_contents = update_contents_handler(cls.update_contents)
@abstractmethod
- def create_component(self, component_id: str, **kwargs: Any) -> ComponentTypes:
+ def create_component(self, component_id: str, **kwargs: Any) -> Component:
"""Create a Dash Component/Container without any data.
+
Args:
component_id (str): The ID of the web component
kwargs (Any): Any additional arguments to pass to the component
+
Returns:
- Union[dcc.Graph, dash_table.DataTable, dcc.Markdown, html.Div]: The Dash Web component
+ Component: A Dash Base Component
"""
pass
- def update_contents(self, data_object: SageworksObject) -> ContentTypes:
+ def update_contents(self, data_object: SageworksObject) -> list:
"""Update the contents of the component/container
+
Args:
- data_object (sageworks_object): The instantiated data object for the plugin type.
+ data_object (sageworks_object/dataframe): A SageWorks object or DataFrame
+
Returns:
- Union[go.Figure, str]: A Plotly Figure or a Markdown string
+ list: A list of the updated contents for EACH slot in the plugin
"""
pass
- def component_id(self) -> str:
+ def generate_component_id(self) -> str:
"""This helper method returns the component ID for the component
Returns:
str: An auto generated component ID
@@ -106,7 +119,8 @@ def wrapper(*args, **kwargs):
# Get the class name of the plugin
class_name = args[0].__class__.__name__ if args else "UnknownPlugin"
error_info = f"{class_name} Crashed: {e.__class__.__name__}: {e}"
- figure = ComponentInterface.display_text(error_info, figure_height=100, font_size=16)
+ error_info = "
".join(textwrap.wrap(error_info, width=120))
+ figure = ComponentInterface.display_text(error_info, figure_height=200, font_size=14)
return dcc.Graph(id="error", figure=figure)
return wrapper
@@ -114,13 +128,30 @@ def wrapper(*args, **kwargs):
def update_contents_handler(func):
@wraps(func)
- def wrapper(*args, **kwargs):
+ def wrapper(self, *args, **kwargs):
try:
- return func(*args, **kwargs)
+ return func(self, *args, **kwargs)
except Exception as e:
# Get the class name of the plugin
- class_name = args[0].__class__.__name__ if args else "UnknownPlugin"
+ class_name = self.__class__.__name__
error_info = f"{class_name} Crashed: {e.__class__.__name__}: {e}"
- return ComponentInterface.display_text(error_info, figure_height=100, font_size=16)
+ log.critical(error_info)
+ error_info = "
".join(textwrap.wrap(error_info, width=120))
+ figure = ComponentInterface.display_text(error_info, figure_height=200, font_size=14)
+
+ # Prepare the error output to match the content_slots format
+ error_output = []
+ for component_id, property in self.slots:
+ if property == "figure":
+ error_output.append(figure)
+ elif property in ["children", "value", "data"]:
+ error_output.append(error_info)
+ elif property == "columnDefs":
+ error_output.append([{"headerName": "Crash", "field": "Crash"}])
+ elif property == "rowData":
+ error_output.append([{"Crash": error_info}])
+ else:
+ error_output.append(None) # Fallback for other properties
+ return error_output
return wrapper
diff --git a/src/sageworks/web_components/data_table.py b/src/sageworks/web_components/data_table.py
new file mode 100644
index 000000000..c9c92061e
--- /dev/null
+++ b/src/sageworks/web_components/data_table.py
@@ -0,0 +1,65 @@
+import logging
+import pandas as pd
+from dash import dash_table
+
+# SageWorks Imports
+from sageworks.web_components.plugin_interface import PluginInterface, PluginPage, PluginInputType
+
+# Get the SageWorks logger
+log = logging.getLogger("sageworks")
+
+
+class DataTable(PluginInterface):
+ """DataTable Component"""
+
+ """Initialize this Plugin Component Class with required attributes"""
+ auto_load_page = PluginPage.NONE
+ plugin_input_type = PluginInputType.MODEL_TABLE
+
+ def create_component(self, component_id: str) -> dash_table.DataTable:
+ """Create a Table Component without any data."""
+ self.component_id = component_id
+ self.container = dash_table.DataTable(
+ id=component_id,
+ columns=[],
+ data=[],
+ filter_action="native", # Enable filtering
+ sort_action="native", # Enable sorting
+ row_selectable="single", # Enable single row selection
+ selected_rows=[0], # Select the first row by default
+ style_table={"maxHeight": "200px", "overflow": "auto"}, # Table styling
+ )
+
+ # Fill in content slots
+ self.slots = [
+ (self.component_id, "columns"),
+ (self.component_id, "data"),
+ ]
+
+ # Output signals
+ self.signals = [
+ (self.component_id, "selected_rows"),
+ ]
+
+ return self.container
+
+ def update_contents(self, model_table: pd.DataFrame, **kwargs) -> list:
+ """Update the contents for the plugin."""
+ log.important(f"Updating DataTable Plugin with a model table and kwargs: {kwargs}")
+
+ # Convert the DataFrame to a list of dictionaries for DataTable
+ table_data = model_table.to_dict("records")
+
+ # Define column definitions based on the DataFrame
+ columns = [{"name": col, "id": col} for col in model_table.columns]
+
+ # Return the column definitions and table data (must match the content slots)
+ return [columns, table_data]
+
+
+if __name__ == "__main__":
+ # Run the Unit Test for the Plugin
+ from sageworks.web_components.plugin_unit_test import PluginUnitTest
+
+ # Run the Unit Test on the Plugin
+ PluginUnitTest(DataTable).run()
diff --git a/src/sageworks/web_components/plugin_callbacks.py b/src/sageworks/web_components/plugin_callbacks.py
new file mode 100644
index 000000000..d70c524f5
--- /dev/null
+++ b/src/sageworks/web_components/plugin_callbacks.py
@@ -0,0 +1,60 @@
+from dash import callback, Output
+from dash.exceptions import PreventUpdate
+import logging
+
+# SageWorks Imports
+from sageworks.api import Model, Endpoint
+
+log = logging.getLogger("sageworks")
+
+
+def register_callbacks(plugins, input_sources, object_type):
+ # Construct a list of Output objects dynamically based on the plugins' content_slots
+ outputs = [Output(component_id, property) for plugin in plugins for component_id, property in plugin.slots]
+
+ @callback(
+ outputs,
+ input_sources,
+ prevent_initial_call=True,
+ )
+ def update_plugin_contents(*args):
+ # Unpack the input arguments
+ if object_type == "model":
+ inference_run, selected_rows, table_data = args
+ else: # object_type == 'endpoint'
+ selected_rows, table_data = args
+
+ # Check for no selected rows
+ if not selected_rows or selected_rows[0] is None:
+ raise PreventUpdate
+
+ # Get the selected row data and grab the uuid
+ selected_row_data = table_data[selected_rows[0]]
+ object_uuid = selected_row_data["uuid"]
+
+ # Instantiate the object (Model or Endpoint)
+ if object_type == "model":
+ obj = Model(object_uuid, legacy=True)
+ else: # object_type == 'endpoint'
+ obj = Endpoint(object_uuid, legacy=True)
+
+ # Update the plugins and collect the updated properties for each slot
+ updated_properties = []
+ for plugin in plugins:
+ log.important(f"Updating Plugin: {plugin} with {object_type.capitalize()}: {object_uuid}")
+ if object_type == "model":
+ updated_contents = plugin.update_contents(obj, inference_run=inference_run)
+ else: # object_type == 'endpoint'
+ updated_contents = plugin.update_contents(obj)
+
+ # Assume that the length of contents matches the number of slots for the plugin
+ if len(updated_contents) != len(plugin.slots):
+ raise ValueError(
+ f"Plugin {plugin} has {len(updated_contents)} content values != {len(plugin.slots)} slots."
+ )
+
+ # Append each value from contents to the updated_properties list
+ updated_properties.extend(updated_contents)
+
+ # Return the updated properties for each slot
+ return updated_properties
diff --git a/src/sageworks/web_components/plugin_interface.py b/src/sageworks/web_components/plugin_interface.py
index d63540381..60f6b356b 100644
--- a/src/sageworks/web_components/plugin_interface.py
+++ b/src/sageworks/web_components/plugin_interface.py
@@ -4,19 +4,21 @@
from inspect import signature
from typing import Union, get_args
from enum import Enum
+from dash.development.base_component import Component
# Local Imports
from sageworks.web_components.component_interface import ComponentInterface
class PluginPage(Enum):
- """Plugin Page: Specify which page will autoload the plugin (CUSTOM = Don't autoload)"""
+ """Plugin Page: Specify which page will AUTO load the plugin (CUSTOM/NONE = Don't autoload)"""
DATA_SOURCE = "data_source"
FEATURE_SET = "feature_set"
MODEL = "model"
ENDPOINT = "endpoint"
CUSTOM = "custom"
+ NONE = "none"
class PluginInputType(Enum):
@@ -26,37 +28,36 @@ class PluginInputType(Enum):
FEATURE_SET = "feature_set"
MODEL = "model"
ENDPOINT = "endpoint"
+ MODEL_TABLE = "model_table"
class PluginInterface(ComponentInterface):
"""A Web Plugin Interface
Notes:
- - These methods are ^stateless^, all data should be passed through the
- arguments and the implementations should not reference 'self' variables
- The 'create_component' method must be implemented by the child class
- The 'update_contents' method must be implemented by the child class
"""
@abstractmethod
- def create_component(self, component_id: str) -> ComponentInterface.ComponentTypes:
+ def create_component(self, component_id: str) -> Component:
"""Create a Dash Component without any data.
+
Args:
component_id (str): The ID of the web component
+
Returns:
- Union[dcc.Graph, dash_table.DataTable, dcc.Markdown, html.Div] The Dash Web component
+ Component: A Dash Base Component
"""
pass
@abstractmethod
- def update_contents(
- self, data_object: ComponentInterface.SageworksObject, **kwargs
- ) -> ComponentInterface.ContentTypes:
- """Generate a figure from the data in the given dataframe.
+ def update_contents(self, data_object: ComponentInterface.SageworksObject, **kwargs) -> list:
+ """Generate the contents for the plugin component
Args:
data_object (sageworks_object): The instantiated data object for the plugin type.
**kwargs: Additional keyword arguments (plugins can define their own arguments)
Returns:
- Union[go.Figure, str]: A Plotly Figure or a Markdown string
+ list: A list of the updated contents for EACH slot in the plugin
"""
pass
@@ -66,20 +67,20 @@ def update_contents(
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
- # Ensure the subclass defines the required plugin_page and plugin_input_type
- if not hasattr(cls, "plugin_page") or not isinstance(cls.plugin_page, PluginPage):
- raise TypeError("Subclasses must define a 'plugin_page' of type PluginPage")
+ # Ensure the subclass defines the required auto_load_page and plugin_input_type
+ if not hasattr(cls, "auto_load_page") or not isinstance(cls.auto_load_page, PluginPage):
+ raise TypeError("Subclasses must define a 'auto_load_page' of type PluginPage")
if not hasattr(cls, "plugin_input_type") or not isinstance(cls.plugin_input_type, PluginInputType):
raise TypeError("Subclasses must define a 'plugin_input_type' of type PluginInputType")
- # If any base class method or parameter is missing from a subclass, or if a subclass method parameter is not
- # correctly typed a call of issubclass(subclass, cls) will return False, allowing runtime checks for plugins
- # The plugin loader calls issubclass(subclass, cls) to determine if the subclass is a valid plugin
+ # This subclass check ensures that a subclass of PluginInterface has all required attributes, methods,
+ # and signatures. It returns False any thing is incorrect enabling runtime validation for plugins.
+ # The plugin loader uses issubclass(subclass, cls) to verify plugin subclasses.
@classmethod
def __subclasshook__(cls, subclass):
if cls is PluginInterface:
# Check if the subclass has all the required attributes
- if not all(hasattr(subclass, attr) for attr in ("plugin_page", "plugin_input_type")):
+ if not all(hasattr(subclass, attr) for attr in ("auto_load_page", "plugin_input_type")):
cls.log.warning(f"Subclass {subclass.__name__} is missing required attributes")
return False
@@ -135,26 +136,23 @@ def _check_argument_types(cls, base_class_method, subclass_method):
if expected != actual:
return f"Expected argument type {expected} does not match actual argument type {actual}"
- # Check for **kwargs in the update_contents method
- # WIP
- """
- if base_class_method.__name__ == "update_contents":
- if not any(param.kind == param.VAR_KEYWORD for param in signature(subclass_method).parameters.values()):
- return "Expected **kwargs in update_contents method arguments, but it was not found."
- """
-
return None
@classmethod
def _check_return_type(cls, base_class_method, subclass_method):
- return_annotation = base_class_method.__annotations__["return"]
- expected_return_types = get_args(return_annotation)
- return_type = subclass_method.__annotations__.get("return", None)
+ method_name = base_class_method.__name__
+ actual_return_type = subclass_method.__annotations__.get("return", None)
+
+ if actual_return_type is None:
+ return "Missing return type annotation in subclass method."
- # Treat None as NoneType for return type comparison
- if return_type is None:
- return_type = type(None)
+ if method_name == "create_component":
+ if not issubclass(actual_return_type, Component):
+ return (
+ f"Incorrect return type for {method_name} (expected Component, got {actual_return_type.__name__})"
+ )
+ elif method_name == "update_contents":
+ if not (actual_return_type == list or (getattr(actual_return_type, "__origin__", None) is list)):
+ return f"Incorrect return type for {method_name} (expected list, got {actual_return_type.__name__})"
- if return_type not in expected_return_types:
- return f"Incorrect return type (expected one of {expected_return_types}, got {return_type})"
return None
diff --git a/src/sageworks/web_components/plugin_unit_test.py b/src/sageworks/web_components/plugin_unit_test.py
new file mode 100644
index 000000000..834eff3f8
--- /dev/null
+++ b/src/sageworks/web_components/plugin_unit_test.py
@@ -0,0 +1,73 @@
+import dash
+from dash import html, Output, Input
+
+# SageWorks Imports
+from sageworks.web_components.plugin_interface import PluginInterface, PluginInputType
+from sageworks.api import Model, Endpoint, Meta
+
+
+class PluginUnitTest:
+ def __init__(self, plugin_class):
+ """A class to unit test a PluginInterface class.
+
+ Args:
+ plugin_class (PluginInterface): The PluginInterface class to test
+ """
+ assert issubclass(plugin_class, PluginInterface), "Plugin class must be a subclass of PluginInterface"
+
+ # Get the input type of the plugin
+ plugin_input_type = plugin_class.plugin_input_type
+
+ # Instantiate the plugin
+ self.plugin = plugin_class()
+ self.component = self.plugin.create_component(f"{self.plugin.__class__.__name__.lower()}_test")
+
+ # Create the Dash app
+ self.app = dash.Dash(__name__)
+
+ # Set up the layout
+ layout_children = [self.component, html.Button("Update Plugin", id="update-button")]
+
+ # Signal output displays
+ layout_children.append(html.H3("Signals:"))
+ for component_id, property in self.plugin.signals:
+ # A Row with the component ID and property and an output div
+ layout_children.append(html.H4(f"Property: {property}"))
+ layout_children.append(html.Div(id=f"test-output-{component_id}-{property}"))
+
+ self.app.layout = html.Div(layout_children)
+
+ # Set up the test callback for updating the plugin
+ @self.app.callback(
+ [Output(component_id, property) for component_id, property in self.plugin.slots],
+ [Input("update-button", "n_clicks")],
+ prevent_initial_call=True,
+ )
+ def update_plugin_contents(n_clicks):
+ # Simulate updating the plugin with a new Model, Endpoint, or Model Table
+ if plugin_input_type == PluginInputType.MODEL:
+ model = Model("abalone-regression")
+ updated_contents = self.plugin.update_contents(model)
+ elif plugin_input_type == PluginInputType.ENDPOINT:
+ endpoint = Endpoint("abalone-regression-end")
+ updated_contents = self.plugin.update_contents(endpoint)
+ elif plugin_input_type == PluginInputType.MODEL_TABLE:
+ model_table = Meta().models()
+ updated_contents = self.plugin.update_contents(model_table)
+ else:
+ raise ValueError(f"Invalid test type: {plugin_input_type}")
+
+ # Return the updated contents based on the plugin's slots
+ return updated_contents
+
+ # Set up callbacks for displaying output signals
+ for component_id, property in self.plugin.signals:
+
+ @self.app.callback(
+ Output(f"test-output-{component_id}-{property}", "children"), Input(component_id, property)
+ )
+ def display_output_signal(signal_value):
+ return f"{signal_value}"
+
+ def run(self):
+ self.app.run_server(debug=True)
diff --git a/tests/delete_test_artifacts.py b/tests/delete_test_artifacts.py
index 20666e275..a8d505a70 100644
--- a/tests/delete_test_artifacts.py
+++ b/tests/delete_test_artifacts.py
@@ -63,6 +63,10 @@
if m.exists():
print("Deleting abalone-regression model...")
m.delete()
+ m = Model("abalone-regression-full")
+ if m.exists():
+ print("Deleting abalone-regression-full model...")
+ m.delete()
end = Endpoint("abalone-regression-end")
if end.exists():
print("Deleting abalone-regression-end...")
diff --git a/tests/plugin_tests/crashing_plugin.py b/tests/plugin_tests/crashing_plugin.py
index c12500901..4f77fbbbe 100644
--- a/tests/plugin_tests/crashing_plugin.py
+++ b/tests/plugin_tests/crashing_plugin.py
@@ -6,13 +6,14 @@
# SageWorks Imports
from sageworks.web_components.plugin_interface import PluginInterface, PluginPage, PluginInputType
+from sageworks.api.model import Model
class CrashingPlugin(PluginInterface):
"""CrashingPlugin Component"""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.MODEL
+ auto_load_page = PluginPage.MODEL
plugin_input_type = PluginInputType.MODEL
def create_component(self, component_id: str) -> dcc.Graph:
@@ -24,16 +25,16 @@ def create_component(self, component_id: str) -> dcc.Graph:
"""
return dcc.Graph(id=component_id, figure=self.display_text("Waiting for Data..."))
- def update_contents(self, model_details: dict) -> go.Figure:
+ def update_contents(self, model: Model) -> go.Figure:
"""Create a CrashingPlugin Figure for the numeric columns in the dataframe.
Args:
- model_details (dict): The model details dictionary (see Model.details())
+ model (Model): A Model Object
Returns:
go.Figure: A Figure object containing the confusion matrix.
"""
# This is where the plugin crashes
- my_bad = model_details["bad_key"]
+ my_bad = model.summary()["bad_key"]
# Create the nested pie chart plot with custom settings
fig = go.Figure(my_bad)
diff --git a/tests/web_components/plugin_interface_test.py b/tests/web_components/plugin_interface_test.py
index c8612fe6e..4f45c0c4f 100644
--- a/tests/web_components/plugin_interface_test.py
+++ b/tests/web_components/plugin_interface_test.py
@@ -13,9 +13,12 @@ class CorrectPlugin(PluginInterface):
"""Subclass of PluginInterface with correct inputs and returns."""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.MODEL
+ auto_load_page = PluginPage.MODEL
plugin_input_type = PluginInputType.MODEL
+ def __init__(self):
+ self.container = None
+
def create_component(self, component_id: str) -> dcc.Graph:
"""Create a Confusion Matrix Component without any data.
Args:
@@ -23,16 +26,15 @@ def create_component(self, component_id: str) -> dcc.Graph:
Returns:
dcc.Graph: The Confusion Matrix Component
"""
- return dcc.Graph(id=component_id, figure=self.waiting_figure())
+ self.container = dcc.Graph(id=component_id, figure=self.waiting_figure())
- def update_contents(self, model: Model) -> go.Figure:
+ def update_contents(self, model: Model) -> list:
"""Create a Confusion Matrix Figure for the numeric columns in the dataframe.
Args:
model (Model): An instantiated Model object
- Returns:
- go.Figure: A Plotly Figure object
"""
- return PluginInterface.display_text("I'm a good plugin...")
+ text_figure = PluginInterface.display_text("I'm a good plugin...")
+ return [text_figure]
class IncorrectMethods(PluginInterface):
@@ -40,7 +42,7 @@ class IncorrectMethods(PluginInterface):
they have create_component but forgot to implement update_contents"""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.MODEL
+ auto_load_page = PluginPage.MODEL
plugin_input_type = PluginInputType.MODEL
def create_component(self, component_id: str) -> dcc.Graph:
@@ -57,7 +59,7 @@ class IncorrectArgTypes(PluginInterface):
"""Subclass of PluginInterface with an incorrectly typed argument."""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.MODEL
+ auto_load_page = PluginPage.MODEL
plugin_input_type = PluginInputType.MODEL
# Component is an incorrectly named keyword argument
@@ -84,7 +86,7 @@ class IncorrectReturnType(PluginInterface):
"""Subclass of PluginInterface with incorrect return type."""
"""Initialize this Plugin Component Class with required attributes"""
- plugin_page = PluginPage.MODEL
+ auto_load_page = PluginPage.MODEL
plugin_input_type = PluginInputType.MODEL
def create_component(self, component_id: str) -> dcc.Graph:
@@ -96,14 +98,14 @@ def create_component(self, component_id: str) -> dcc.Graph:
"""
return dcc.Graph(id=component_id, figure=self.waiting_figure())
- def update_contents(self, model: Model) -> list:
+ def update_contents(self, model: Model) -> go.Figure:
"""Create a Figure but give the wrong return type.
Args:
model (Model): An instantiated Model object
Returns:
list: An incorrect return type
"""
- return [1, 2, 3] # Incorrect return type
+ return go.Figure() # Incorrect return type
def test_incorrect_methods():