Skip to content

Commit

Permalink
Merge pull request #440 from SuperCowPowers/model_details_refactor
Browse files Browse the repository at this point in the history
Model details refactor
  • Loading branch information
brifordwylie authored Mar 28, 2024
2 parents 890f2a7 + b08a21b commit 8e11b9f
Show file tree
Hide file tree
Showing 12 changed files with 444 additions and 534 deletions.
22 changes: 5 additions & 17 deletions applications/aws_dashboard/pages/endpoints/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

# SageWorks Imports
from sageworks.views.endpoint_web_view import EndpointWebView
from sageworks.web_components import table, model_details_markdown, endpoint_metric_plots
from sageworks.web_components import table, endpoint_metric_plots
from sageworks.utils.pandas_utils import deserialize_aws_broker_data
from sageworks.api.endpoint import Endpoint
from sageworks.api.model import Model


def update_endpoints_table(app: Dash):
Expand Down Expand Up @@ -54,13 +53,9 @@ def style_selected_rows(selected_rows):


# Updates the endpoint details when a endpoint row is selected
def update_endpoint_details_components(app: Dash, endpoint_web_view: EndpointWebView):
def update_endpoint_metrics(app: Dash, endpoint_web_view: EndpointWebView):
@app.callback(
[
Output("endpoint_details_header", "children"),
Output("endpoint_details", "children"),
Output("endpoint_metrics", "figure"),
],
Output("endpoint_metrics", "figure"),
Input("endpoints_table", "derived_viewport_selected_row_ids"),
State("endpoints_table", "data"),
prevent_initial_call=True,
Expand All @@ -75,21 +70,14 @@ def generate_endpoint_details_figures(selected_rows, table_data):
endpoint_uuid = selected_row_data["uuid"]
print(f"Endpoint UUID: {endpoint_uuid}")

# Set the Header Text
header = f"Details: {endpoint_uuid}"

# Endpoint Details
endpoint_details = endpoint_web_view.endpoint_details(endpoint_uuid)

# Model Details Markdown component (Review This)
model = Model(Endpoint(endpoint_uuid).model_name)
endpoint_details_markdown = model_details_markdown.ModelDetailsMarkdown().generate_markdown(model)

# Endpoint Metrics
endpoint_metrics_figure = endpoint_metric_plots.EndpointMetricPlots().generate_figure(endpoint_details)

# Return the details/markdown for these data details
return [header, endpoint_details_markdown, endpoint_metrics_figure]
return endpoint_metrics_figure


# Updates the plugin component when a endpoint row is selected
Expand All @@ -110,7 +98,7 @@ def update_callback(selected_rows, table_data):
endpoint_uuid = selected_row_data["uuid"]

# Instantiate the Endpoint and send it to the plugin
endpoint = Endpoint(endpoint_uuid)
endpoint = Endpoint(endpoint_uuid, legacy=True)

# Instantiate the Endpoint and send it to the plugin
return plugin.generate_figure(endpoint)
10 changes: 6 additions & 4 deletions applications/aws_dashboard/pages/endpoints/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from . import callbacks

# SageWorks Imports
from sageworks.web_components import table, model_details_markdown, endpoint_metric_plots
from sageworks.web_components import table, endpoint_details, endpoint_metric_plots
from sageworks.web_components.plugin_interface import PluginPage
from sageworks.views.endpoint_web_view import EndpointWebView
from sageworks.utils.plugin_manager import PluginManager
Expand All @@ -33,15 +33,16 @@
)

# Create a Markdown component to display the endpoint details
endpoint_details = model_details_markdown.ModelDetailsMarkdown().create_component("endpoint_details")
endpoint_details = endpoint_details.EndpointDetails()
endpoint_details_component = endpoint_details.create_component("endpoint_details")

# Create a component to display the endpoint metrics
endpoint_metrics = endpoint_metric_plots.EndpointMetricPlots().create_component("endpoint_metrics")

# Capture our components in a dictionary to send off to the layout
components = {
"endpoints_table": endpoints_table,
"endpoint_details": endpoint_details,
"endpoint_details": endpoint_details_component,
"endpoint_metrics": endpoint_metrics,
}

Expand All @@ -60,10 +61,11 @@
# Setup our callbacks/connections
app = dash.get_app()
callbacks.update_endpoints_table(app)
endpoint_details.register_callbacks("endpoints_table")

# Callback for the endpoints table
callbacks.table_row_select(app, "endpoints_table")
callbacks.update_endpoint_details_components(app, endpoint_broker)
callbacks.update_endpoint_metrics(app, endpoint_broker)

# For each plugin, set up a callback to update the plugin figure
for plugin in plugins:
Expand Down
97 changes: 5 additions & 92 deletions applications/aws_dashboard/pages/models/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from sageworks.views.model_web_view import ModelWebView
from sageworks.web_components import (
table,
model_details_markdown,
model_metrics_markdown,
model_plot,
)
from sageworks.utils.pandas_utils import deserialize_aws_broker_data
Expand Down Expand Up @@ -54,108 +52,23 @@ def style_selected_rows(selected_rows):
return row_style


# Updates the model details when a model row is selected
def update_model_detail_component(app: Dash):
@app.callback(
[Output("model_details_header", "children"), Output("model_details", "children")],
Input("models_table", "derived_viewport_selected_row_ids"),
State("models_table", "data"),
prevent_initial_call=True,
)
def generate_model_details_figure(selected_rows, table_data):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update

# Get the selected row data and grab the uuid
selected_row_data = table_data[selected_rows[0]]
model_uuid = selected_row_data["uuid"]
m = Model(model_uuid)

# Set the Header Text
header = f"Model: {model_uuid}"

# Model Details Markdown component
model_details_fig = model_details_markdown.ModelDetailsMarkdown().generate_markdown(m)

# Return the details/markdown for these data details
return [header, model_details_fig]


# Updates Inference Run Selector Component
def update_inference_dropdown(app: Dash):
@app.callback(
[Output("inference_dropdown", "options"), Output("inference_dropdown", "value")],
Input("models_table", "derived_viewport_selected_row_ids"),
State("models_table", "data"),
prevent_initial_call=True,
)
def generate_inference_dropdown_figure(selected_rows, table_data):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update

# Get the selected row data and grab the uuid
selected_row_data = table_data[selected_rows[0]]
model_uuid = selected_row_data["uuid"]
m = Model(model_uuid)

# Inference runs
inference_runs = m.list_inference_runs()

# Check if there are any inference runs to select
if not inference_runs:
return [], None

# Set "training_holdout" as the default, if that doesn't exist, set the first
default_inference_run = "training_holdout" if "training_holdout" in inference_runs else inference_runs[0]

# Return the options for the dropdown and the selected value
return inference_runs, default_inference_run


# Updates the model metrics when a model row is selected
def update_model_metrics_component(app: Dash):
@app.callback(
Output("model_metrics", "children"),
[Input("models_table", "derived_viewport_selected_row_ids"), Input("inference_dropdown", "value")],
State("models_table", "data"),
prevent_initial_call=True,
)
def generate_model_metrics_figure(selected_rows, inference_run, table_data):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update

# Get the selected row data and grab the uuid
selected_row_data = table_data[selected_rows[0]]
model_uuid = selected_row_data["uuid"]
m = Model(model_uuid)

# Model Details Markdown component
model_metrics_fig = model_metrics_markdown.ModelMetricsMarkdown().generate_markdown(m, inference_run)

# Return the details/markdown for these data details
return model_metrics_fig


# Updates the model plot when a model row is selected
def update_model_plot_component(app: Dash):
@app.callback(
Output("model_plot", "figure"),
[Input("models_table", "derived_viewport_selected_row_ids"), Input("inference_dropdown", "value")],
State("models_table", "data"),
Input("model_details-dropdown", "value"),
[State("models_table", "data"), State("models_table", "derived_viewport_selected_row_ids")],
prevent_initial_call=True,
)
def generate_model_plot_figure(selected_rows, inference_run, table_data):
def generate_model_plot_figure(inference_run, table_data, selected_rows):
# Check for no selected rows
if not selected_rows or selected_rows[0] is None:
return no_update

# Get the selected row data and grab the uuid
selected_row_data = table_data[selected_rows[0]]
model_uuid = selected_row_data["uuid"]
m = Model(model_uuid)
m = Model(model_uuid, legacy=True)

# Model Details Markdown component
model_plot_fig = model_plot.ModelPlot().generate_figure(m, inference_run)
Expand All @@ -182,5 +95,5 @@ def update_plugin_figure(selected_rows, table_data):
model_uuid = selected_row_data["uuid"]

# Instantiate the Model and send it to the plugin
model = Model(model_uuid)
model = Model(model_uuid, legacy=True)
return plugin.generate_figure(model)
20 changes: 2 additions & 18 deletions applications/aws_dashboard/pages/models/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

def models_layout(
models_table: dash_table.DataTable,
inference_dropdown: dcc.Graph,
model_details: dcc.Markdown,
model_metrics: dcc.Markdown,
model_details: html.Div,
model_plot: dcc.Graph,
**kwargs: Any,
) -> html.Div:
Expand All @@ -35,21 +33,7 @@ def models_layout(
dbc.Row(
[
# Column 1: Model Details
dbc.Col(
[
dbc.Row(
[html.H3("Model: Loading...", id="model_details_header"),
model_metrics,
inference_dropdown],
style={"padding": "30px 0px 40px 0px"},
),
dbc.Row(
[html.H3("Model Details"), model_details],
style={"padding": "0px 0px 30px 0px"},
),
],
width=4,
),
dbc.Col(model_details, width=4),
# Column 2: Model Plot and Plugins
dbc.Col(
[
Expand Down
21 changes: 5 additions & 16 deletions applications/aws_dashboard/pages/models/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dash import register_page
import dash
from dash_bootstrap_templates import load_figure_template
from dash import dcc

# Local Imports
from .layout import models_layout
Expand All @@ -12,8 +11,7 @@
# SageWorks Imports
from sageworks.web_components import (
table,
model_details_markdown,
model_metrics_markdown,
model_details,
model_plot,
)
from sageworks.web_components.plugin_interface import PluginPage
Expand All @@ -39,23 +37,16 @@
)

# Create a Markdown component to display the model details
model_details = model_details_markdown.ModelDetailsMarkdown().create_component("model_details")

# Create a Inference Run Dropdown component
inference_dropdown = dcc.Dropdown(id="inference_dropdown", className="dropdown")

# Create a Markdown component to display model metrics
model_metrics = model_metrics_markdown.ModelMetricsMarkdown().create_component("model_metrics")
model_details = model_details.ModelDetails()
model_details_component = model_details.create_component("model_details")

# Create a Model Plot component to display the model metrics
model_plot_component = model_plot.ModelPlot().create_component("model_plot")

# Capture our components in a dictionary to send off to the layout
components = {
"models_table": models_table,
"inference_dropdown": inference_dropdown,
"model_details": model_details,
"model_metrics": model_metrics,
"model_details": model_details_component,
"model_plot": model_plot_component,
}

Expand All @@ -74,12 +65,10 @@
# Setup our callbacks/connections
app = dash.get_app()
callbacks.update_models_table(app)
model_details.register_callbacks("models_table")

# Callback for the model table
callbacks.table_row_select(app, "models_table")
callbacks.update_inference_dropdown(app)
callbacks.update_model_detail_component(app)
callbacks.update_model_metrics_component(app)
callbacks.update_model_plot_component(app)

# For each plugin, set up a callback to update the plugin figure
Expand Down
62 changes: 0 additions & 62 deletions src/sageworks/web_components/deprecated/plugin_loader.py

This file was deleted.

Loading

0 comments on commit 8e11b9f

Please sign in to comment.