From c60f69e7cb6cbfa9fa22c17dfbba704321ee8c72 Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Wed, 27 Mar 2024 12:53:42 -0600 Subject: [PATCH 1/6] removing some deprecated code --- .../deprecated/plugin_loader.py | 62 ------------------- 1 file changed, 62 deletions(-) delete mode 100644 src/sageworks/web_components/deprecated/plugin_loader.py diff --git a/src/sageworks/web_components/deprecated/plugin_loader.py b/src/sageworks/web_components/deprecated/plugin_loader.py deleted file mode 100644 index 70990f37c..000000000 --- a/src/sageworks/web_components/deprecated/plugin_loader.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import importlib.util -from typing import List -import logging -import inspect - -# SageWorks imports -from sageworks.web_components.plugin_interface import PluginInterface, PluginPage - - -# SageWorks Logger -log = logging.getLogger("sageworks") - - -def load_plugins_from_dir(directory: str, plugin_page: PluginPage) -> List[PluginInterface]: - """Load all the plugins from the given directory. - Args: - directory (str): The directory to load the plugins from. - plugin_page (PluginPage): The type of plugin to load. - Returns: - List[PluginInterface]: A list of plugins that were loaded. - """ - - if not os.path.isdir(directory): - log.warning(f"Directory {directory} does not exist. No plugins loaded.") - return [] - - plugins = [] - for filename in os.listdir(directory): - if filename.endswith(".py"): - file_path = os.path.join(directory, filename) - spec = importlib.util.spec_from_file_location(filename, file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - for _, attribute in inspect.getmembers(module, inspect.isclass): - if attribute.__module__ == module.__name__: - if issubclass(attribute, PluginInterface) and attribute is not PluginInterface: - try: - instance = attribute() - if instance.plugin_page == plugin_page: - plugins.append(instance) - except TypeError as e: - log.error(f"Error initializing plugin from {filename}: {e}") - else: - log.warning(f"Class {attribute.__name__} in {filename} invalid PluginInterface subclass") - - return plugins - - -if __name__ == "__main__": - # Example of loading plugins from a directory - from sageworks.utils.config_manager import ConfigManager - - # Get the plugin directory from the environment variable - plugin_dir = ConfigManager().get_config("SAGEWORKS_PLUGINS") - - # Loop through the various plugin types and load the plugins - plugin_pages = [PluginPage.DATA_SOURCE, PluginPage.FEATURE_SET, PluginPage.MODEL, PluginPage.ENDPOINT] - for plugin_page in plugin_pages: - plugins = load_plugins_from_dir(plugin_dir, plugin_page) - for plugin in plugins: - log.info(f"Loaded {plugin_page} plugin: {plugin.__class__.__name__}") From 75319a6d47e098a3b7a6439f51225eebfb453c7e Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Thu, 28 Mar 2024 08:58:13 -0600 Subject: [PATCH 2/6] adding a new composite class for model details that encapsulates several related components and manages internal callbacks --- .../aws_dashboard/pages/models/callbacks.py | 95 +------ .../aws_dashboard/pages/models/layout.py | 21 +- .../aws_dashboard/pages/models/page.py | 20 +- .../web_components/inference_run_selector.py | 68 ----- src/sageworks/web_components/model_details.py | 254 ++++++++++++++++++ .../web_components/model_details_markdown.py | 134 --------- .../web_components/model_metrics_markdown.py | 95 ------- 7 files changed, 267 insertions(+), 420 deletions(-) delete mode 100644 src/sageworks/web_components/inference_run_selector.py create mode 100644 src/sageworks/web_components/model_details.py delete mode 100644 src/sageworks/web_components/model_details_markdown.py delete mode 100644 src/sageworks/web_components/model_metrics_markdown.py diff --git a/applications/aws_dashboard/pages/models/callbacks.py b/applications/aws_dashboard/pages/models/callbacks.py index f0a93f171..e8a7cdc63 100644 --- a/applications/aws_dashboard/pages/models/callbacks.py +++ b/applications/aws_dashboard/pages/models/callbacks.py @@ -7,8 +7,7 @@ from sageworks.views.model_web_view import ModelWebView from sageworks.web_components import ( table, - model_details_markdown, - model_metrics_markdown, + model_details, model_plot, ) from sageworks.utils.pandas_utils import deserialize_aws_broker_data @@ -54,100 +53,16 @@ def style_selected_rows(selected_rows): return row_style -# Updates the model details when a model row is selected -def update_model_detail_component(app: Dash): - @app.callback( - [Output("model_details_header", "children"), Output("model_details", "children")], - Input("models_table", "derived_viewport_selected_row_ids"), - State("models_table", "data"), - prevent_initial_call=True, - ) - def generate_model_details_figure(selected_rows, table_data): - # Check for no selected rows - if not selected_rows or selected_rows[0] is None: - return no_update - - # Get the selected row data and grab the uuid - selected_row_data = table_data[selected_rows[0]] - model_uuid = selected_row_data["uuid"] - m = Model(model_uuid) - - # Set the Header Text - header = f"Model: {model_uuid}" - - # Model Details Markdown component - model_details_fig = model_details_markdown.ModelDetailsMarkdown().generate_markdown(m) - - # Return the details/markdown for these data details - return [header, model_details_fig] - - -# Updates Inference Run Selector Component -def update_inference_dropdown(app: Dash): - @app.callback( - [Output("inference_dropdown", "options"), Output("inference_dropdown", "value")], - Input("models_table", "derived_viewport_selected_row_ids"), - State("models_table", "data"), - prevent_initial_call=True, - ) - def generate_inference_dropdown_figure(selected_rows, table_data): - # Check for no selected rows - if not selected_rows or selected_rows[0] is None: - return no_update - - # Get the selected row data and grab the uuid - selected_row_data = table_data[selected_rows[0]] - model_uuid = selected_row_data["uuid"] - m = Model(model_uuid) - - # Inference runs - inference_runs = m.list_inference_runs() - - # Check if there are any inference runs to select - if not inference_runs: - return [], None - - # Set "training_holdout" as the default, if that doesn't exist, set the first - default_inference_run = "training_holdout" if "training_holdout" in inference_runs else inference_runs[0] - - # Return the options for the dropdown and the selected value - return inference_runs, default_inference_run - - -# Updates the model metrics when a model row is selected -def update_model_metrics_component(app: Dash): - @app.callback( - Output("model_metrics", "children"), - [Input("models_table", "derived_viewport_selected_row_ids"), Input("inference_dropdown", "value")], - State("models_table", "data"), - prevent_initial_call=True, - ) - def generate_model_metrics_figure(selected_rows, inference_run, table_data): - # Check for no selected rows - if not selected_rows or selected_rows[0] is None: - return no_update - - # Get the selected row data and grab the uuid - selected_row_data = table_data[selected_rows[0]] - model_uuid = selected_row_data["uuid"] - m = Model(model_uuid) - - # Model Details Markdown component - model_metrics_fig = model_metrics_markdown.ModelMetricsMarkdown().generate_markdown(m, inference_run) - - # Return the details/markdown for these data details - return model_metrics_fig - - # Updates the model plot when a model row is selected def update_model_plot_component(app: Dash): @app.callback( Output("model_plot", "figure"), - [Input("models_table", "derived_viewport_selected_row_ids"), Input("inference_dropdown", "value")], - State("models_table", "data"), + Input("model_details-dropdown", "value"), + [State("models_table", "data"), + State("models_table", "derived_viewport_selected_row_ids")], prevent_initial_call=True, ) - def generate_model_plot_figure(selected_rows, inference_run, table_data): + def generate_model_plot_figure(inference_run, table_data, selected_rows): # Check for no selected rows if not selected_rows or selected_rows[0] is None: return no_update diff --git a/applications/aws_dashboard/pages/models/layout.py b/applications/aws_dashboard/pages/models/layout.py index 88dadbeb1..28316879a 100644 --- a/applications/aws_dashboard/pages/models/layout.py +++ b/applications/aws_dashboard/pages/models/layout.py @@ -7,9 +7,7 @@ def models_layout( models_table: dash_table.DataTable, - inference_dropdown: dcc.Graph, - model_details: dcc.Markdown, - model_metrics: dcc.Markdown, + model_details: html.Div, model_plot: dcc.Graph, **kwargs: Any, ) -> html.Div: @@ -35,21 +33,8 @@ def models_layout( dbc.Row( [ # Column 1: Model Details - dbc.Col( - [ - dbc.Row( - [html.H3("Model: Loading...", id="model_details_header"), - model_metrics, - inference_dropdown], - style={"padding": "30px 0px 40px 0px"}, - ), - dbc.Row( - [html.H3("Model Details"), model_details], - style={"padding": "0px 0px 30px 0px"}, - ), - ], - width=4, - ), + dbc.Col(model_details, width=4), + # Column 2: Model Plot and Plugins dbc.Col( [ diff --git a/applications/aws_dashboard/pages/models/page.py b/applications/aws_dashboard/pages/models/page.py index 4d4bf33de..666c0a1eb 100644 --- a/applications/aws_dashboard/pages/models/page.py +++ b/applications/aws_dashboard/pages/models/page.py @@ -12,8 +12,7 @@ # SageWorks Imports from sageworks.web_components import ( table, - model_details_markdown, - model_metrics_markdown, + model_details, model_plot, ) from sageworks.web_components.plugin_interface import PluginPage @@ -39,13 +38,8 @@ ) # Create a Markdown component to display the model details -model_details = model_details_markdown.ModelDetailsMarkdown().create_component("model_details") - -# Create a Inference Run Dropdown component -inference_dropdown = dcc.Dropdown(id="inference_dropdown", className="dropdown") - -# Create a Markdown component to display model metrics -model_metrics = model_metrics_markdown.ModelMetricsMarkdown().create_component("model_metrics") +model_details = model_details.ModelDetails() +model_details_component = model_details.create_component("model_details") # Create a Model Plot component to display the model metrics model_plot_component = model_plot.ModelPlot().create_component("model_plot") @@ -53,9 +47,7 @@ # Capture our components in a dictionary to send off to the layout components = { "models_table": models_table, - "inference_dropdown": inference_dropdown, - "model_details": model_details, - "model_metrics": model_metrics, + "model_details": model_details_component, "model_plot": model_plot_component, } @@ -74,12 +66,10 @@ # Setup our callbacks/connections app = dash.get_app() callbacks.update_models_table(app) +model_details.register_callbacks("models_table") # Callback for the model table callbacks.table_row_select(app, "models_table") -callbacks.update_inference_dropdown(app) -callbacks.update_model_detail_component(app) -callbacks.update_model_metrics_component(app) callbacks.update_model_plot_component(app) # For each plugin, set up a callback to update the plugin figure diff --git a/src/sageworks/web_components/inference_run_selector.py b/src/sageworks/web_components/inference_run_selector.py deleted file mode 100644 index c88dd258e..000000000 --- a/src/sageworks/web_components/inference_run_selector.py +++ /dev/null @@ -1,68 +0,0 @@ -"""An Inference Selector Component for models""" - -from dash import dcc - -# SageWorks Imports -from sageworks.api import Model -from sageworks.web_components.component_interface import ComponentInterface - - -class InferenceRunSelector(ComponentInterface): - """Inference Run Selector Component""" - - def create_component(self, component_id: str) -> dcc.Graph: - """Create a Dropdown Component without any data. - Args: - component_id (str): The ID of the web component - Returns: - dcc.Dropdown: A Dropdown component - """ - return dcc.Dropdown(id=component_id) - - def generate_inference_runs(self, model: Model) -> list[str]: - """Generates the inference runs to be used as options for the Dropdown - Args: - model (Model): Sageworks Model object - Returns: - list[str]: A list of inference runs - """ - - # Inference runs - inference_runs = model.list_inference_runs() - - return inference_runs - - -if __name__ == "__main__": - # This class takes in model details and generates a Confusion Matrix - import dash - from dash import dcc, html, Input, Output, callback - import dash_bootstrap_components as dbc - from sageworks.api import Model - - # Instantiate model - m = Model("abalone-regression") - - # Instantiate the class - irs = InferenceRunSelector() - - # Generate the component - dropdown = irs.create_component("dropdown") - inf_runs = irs.generate_inference_runs(m) - - # Initialize Dash app - app = dash.Dash( - __name__, - external_stylesheets=[dbc.themes.DARKLY], - assets_folder="/home/kolmar/sageworks/applications/aws_dashboard/assets", - ) - - app.layout = html.Div([dropdown, html.Div(id="dd-output-container")]) - dropdown.options = inf_runs - - @callback(Output("dd-output-container", "children"), Input("dropdown", "value")) - def update_output(value): - return f"You have selected {value}" - - # Run server - app.run_server(debug=True) diff --git a/src/sageworks/web_components/model_details.py b/src/sageworks/web_components/model_details.py new file mode 100644 index 000000000..2675ccc7b --- /dev/null +++ b/src/sageworks/web_components/model_details.py @@ -0,0 +1,254 @@ +"""A Markdown Component for details/information about Models""" + +# Dash Imports +from dash import html, callback, no_update, dcc +from dash.dependencies import Input, Output, State + +# SageWorks Imports +from sageworks.api import Model +from sageworks.web_components.component_interface import ComponentInterface +from sageworks.utils.symbols import health_icons + + +class ModelDetails(ComponentInterface): + """Model Markdown Component""" + def __init__(self): + self.prefix_id = "" + self.model = None + super().__init__() + + def create_component(self, component_id: str) -> html.Div: + """Create a Markdown Component without any data. + Args: + component_id (str): The ID of the web component + Returns: + html.Div: A Container of Components for the Model Details + """ + self.prefix_id = component_id + container = html.Div( + id=self.prefix_id, + children=[ + html.H3(id=f"{self.prefix_id}-header", children="Model: Loading..."), + dcc.Markdown(id=f"{self.prefix_id}-summary"), + html.H3(children="Inference Metrics"), + dcc.Dropdown(id=f"{self.prefix_id}-dropdown", className="dropdown"), + dcc.Markdown(id=f"{self.prefix_id}-metrics"), + ], + ) + return container + + def register_callbacks(self, model_table): + @callback( + [ + Output(f"{self.prefix_id}-header", "children"), + Output(f"{self.prefix_id}-summary", "children"), + Output(f"{self.prefix_id}-dropdown", "options"), + Output(f"{self.prefix_id}-dropdown", "value") + ], + Input(model_table, "derived_viewport_selected_row_ids"), + State(model_table, "data"), + ) + def update_model(selected_rows, table_data): + # Check for no selected rows + if not selected_rows or selected_rows[0] is None: + return no_update + + # Get the selected row data, grab the uuid, and set the Model object + selected_row_data = table_data[selected_rows[0]] + model_uuid = selected_row_data["uuid"] + self.model = Model(model_uuid) + + # Update the header, the summary, and the details + header = f"Model: {self.model.uuid}" + summary = self.model_summary() + + # Populate the inference runs dropdown + inference_runs, default_run = self.get_inference_runs() + + return header, summary, inference_runs, default_run + + @callback( + Output(f"{self.prefix_id}-metrics", "children"), + Input(f"{self.prefix_id}-dropdown", "value"), + ) + def update_inference_run(inference_run): + # Check for no inference run + if not inference_run: + return no_update + + # Update the model metrics + metrics = self.inference_metrics(inference_run) + + return metrics + + def model_summary(self): + """Construct the markdown string for the model summary + + Returns: + str: A markdown string + """ + # Get these fields from the model + # Get these fields from the model + show_fields = ["health_tags", "input", "sageworks_registered_endpoints", + "sageworks_model_type", "sageworks_tags", "sageworks_model_target", + "sageworks_model_features"] + + # Construct the markdown string + summary = self.model.summary() + markdown = "" + for key in show_fields: + + # Special case for the health tags + if key == "health_tags": + markdown += self._health_tag_markdown(summary.get(key, [])) + continue + + # Special case for the features + if key == "sageworks_model_features": + value = summary.get(key, []) + key = "features" + value = f"({len(value)}) {', '.join(value)[:100]}..." + markdown += f"**{key}:** {value} \n" + continue + + # Get the value + value = summary.get(key, "-") + + # If the value is a list, convert it to a comma-separated string + if isinstance(value, list): + value = ', '.join(value) + + # Chop off the "sageworks_" prefix + key = key.replace("sageworks_", "") + + # Add to markdown string + markdown += f"**{key}:** {value} \n" + + return markdown + + def inference_metrics(self, inference_run: str): + """Construct the markdown string for the model metrics + + Args: + inference_run (str): The inference run to get the metrics for + Returns: + str: A markdown string + """ + # Model Metrics + meta_df = self.model.inference_metadata(inference_run) + if meta_df is None: + test_data = "Inference Metadata Not Found" + test_data_hash = " N/A " + test_rows = " - " + description = " - " + else: + inference_meta = meta_df.to_dict(orient="records")[0] + test_data = inference_meta.get("name", " - ") + test_data_hash = inference_meta.get("data_hash", " - ") + test_rows = inference_meta.get("num_rows", " - ") + description = inference_meta.get("description", " - ") + + # Add the markdown for the model test metrics + markdown = "\n" + markdown += f"**Test Data:** {test_data} \n" + markdown += f"**Data Hash:** {test_data_hash} \n" + markdown += f"**Test Rows:** {test_rows} \n" + markdown += f"**Description:** {description} \n" + + # Grab the Metrics from the model details + metrics = self.model.performance_metrics(capture_uuid=inference_run) + if metrics is None: + markdown += " \nNo Data \n" + else: + markdown += " \n" + metrics = metrics.round(3) + markdown += metrics.to_markdown(index=False) + + print(markdown) + return markdown + + def get_inference_runs(self): + """Get the inference runs for the model + + Returns: + list[str]: A list of inference runs + default_run (str): The default inference run + """ + + # Inference runs + inference_runs = self.model.list_inference_runs() + + # Check if there are any inference runs to select + if not inference_runs: + return [], None + + # Set "training_holdout" as the default, if that doesn't exist, set the first + default_inference_run = "training_holdout" if "training_holdout" in inference_runs else inference_runs[0] + + # Return the options for the dropdown and the selected value + return inference_runs, default_inference_run + + @staticmethod + def _health_tag_markdown(health_tags: list[str]) -> str: + """Internal method to generate the health tag markdown + Args: + health_tags (list[str]): A list of health tags + Returns: + str: A markdown string + """ + # If we have no health tags, then add a bullet for healthy + markdown = "**Health Checks**\n" # Header for Health Checks + + # If we have no health tags, then add a bullet for healthy + if not health_tags: + markdown += f"* Healthy: {health_icons.get('healthy')}\n\n" + return markdown + + # Special case for no_activity with no other tags + if len(health_tags) == 1 and health_tags[0] == "no_activity": + markdown += f"* Healthy: {health_icons.get('healthy')}\n" + markdown += f"* No Activity: {health_icons.get('no_activity')}\n\n" + return markdown + + # If we have health tags, then add a bullet for each tag + markdown += "\n".join(f"* {tag}: {health_icons.get(tag, '')}" for tag in health_tags) + markdown += "\n\n" # Add newlines for separation + return markdown + + +if __name__ == "__main__": + # This class takes in model details and generates a details Markdown component + import dash + import dash_bootstrap_components as dbc + from sageworks.web_components.table import Table + from sageworks.views.artifacts_web_view import ArtifactsWebView + + # Create a model table + models_table = Table().create_component( + "models_table", header_color="rgb(60, 100, 60)", row_select="single", max_height=270 + ) + + # Populate the table with data + view = ArtifactsWebView() + models = view.models_summary() + models["id"] = range(len(models)) + column_setup_list = Table().column_setup(models, markdown_columns=["Model Group"]) + models_table.columns = column_setup_list + models_table.data = models.to_dict("records") + + # Instantiate the ModelDetails class + md = ModelDetails() + details_component = md.create_component("model_details") + + # Register the callbacks + md.register_callbacks("models_table") + + # Initialize Dash app + app = dash.Dash( + __name__, + external_stylesheets=[dbc.themes.DARKLY], + assets_folder="/Users/briford/work/sageworks/applications/aws_dashboard/assets", + ) + + app.layout = html.Div([models_table, details_component]) + app.run_server(debug=True) diff --git a/src/sageworks/web_components/model_details_markdown.py b/src/sageworks/web_components/model_details_markdown.py deleted file mode 100644 index 879991241..000000000 --- a/src/sageworks/web_components/model_details_markdown.py +++ /dev/null @@ -1,134 +0,0 @@ -"""A Markdown Component for details/information about Models""" - -import pandas as pd -from dash import dcc - -# SageWorks Imports -from sageworks.api import Model -from sageworks.web_components.component_interface import ComponentInterface -from sageworks.utils.symbols import health_icons - - -class ModelDetailsMarkdown(ComponentInterface): - """Model Markdown Component""" - - def create_component(self, component_id: str) -> dcc.Markdown: - """Create a Markdown Component without any data. - Args: - component_id (str): The ID of the web component - Returns: - dcc.Markdown: The Dash Markdown Component - """ - waiting_markdown = "*Waiting for data...*" - return dcc.Markdown(id=component_id, children=waiting_markdown, dangerously_allow_html=False) - - def generate_markdown(self, model: Model) -> str: - """Create the Markdown for the details/information about the DataSource or the FeatureSet - Args: - model (Model): Sageworks Model object - Returns: - str: A Markdown string - """ - - # Get model details - model_details = model.details() - - # If the model details are empty then return a message - if model_details is None: - return "*No Data*" - - # Create simple markdown by iterating through the model_details dictionary - - # Excluded keys from the model_details dictionary (and any keys that end with '_arn') - exclude = ["size", "uuid", "inference_meta", "model_info"] - top_level_details = { - key: value for key, value in model_details.items() if key not in exclude and not key.endswith("_arn") - } - - # FIXME: Remove this later: Add the model info to the top level details - model_info = model_details.get("model_info", {}) - prefixed_model_info = {f"model_{k}": v for k, v in model_info.items()} - top_level_details.update(prefixed_model_info) - - # Construct the markdown string - markdown = "" - for key, value in top_level_details.items(): - # Special case for the health tags - if key == "health_tags": - markdown += self._health_tag_markdown(value) - continue - - # Special case for dataframes - if isinstance(value, pd.DataFrame): - value_str = "Dataframe" - - else: - # Not sure why str() conversion might fail, but we'll catch it - try: - value_str = str(value)[:100] - except Exception as e: - self.log.error(f"Error converting {key} to string: {e}") - value_str = "*" - - # Add to markdown string - markdown += f"**{key}:** {value_str} \n" - - return markdown - - @staticmethod - def _health_tag_markdown(health_tags: list[str]) -> str: - """Internal method to generate the health tag markdown - Args: - health_tags (list[str]): A list of health tags - Returns: - str: A markdown string - """ - # If we have no health tags, then add a bullet for healthy - markdown = "**Health Checks**\n" # Header for Health Checks - - # If we have no health tags, then add a bullet for healthy - if not health_tags: - markdown += f"* Healthy: {health_icons.get('healthy')}\n\n" - return markdown - - # Special case for no_activity with no other tags - if len(health_tags) == 1 and health_tags[0] == "no_activity": - markdown += f"* Healthy: {health_icons.get('healthy')}\n" - markdown += f"* No Activity: {health_icons.get('no_activity')}\n\n" - return markdown - - # If we have health tags, then add a bullet for each tag - markdown += "\n".join(f"* {tag}: {health_icons.get(tag, '')}" for tag in health_tags) - markdown += "\n\n" # Add newlines for separation - return markdown - - -if __name__ == "__main__": - # This class takes in model details and generates a details Markdown component - import dash - from dash import dcc, html - import dash_bootstrap_components as dbc - from sageworks.api import Model - - # Create the class and get the AWS FeatureSet details - m = Model("wine-classification") - - # Instantiate the DataDetailsMarkdown class - ddm = ModelDetailsMarkdown() - component = ddm.create_component("model_details_markdown") - - # Generate the markdown - markdown = ddm.generate_markdown(m) - - # Initialize Dash app - app = dash.Dash( - __name__, - external_stylesheets=[dbc.themes.DARKLY], - assets_folder="", - ) - - app.layout = html.Div([component]) - component.children = markdown - - if __name__ == "__main__": - app.run_server(debug=True) diff --git a/src/sageworks/web_components/model_metrics_markdown.py b/src/sageworks/web_components/model_metrics_markdown.py deleted file mode 100644 index b350ed1e7..000000000 --- a/src/sageworks/web_components/model_metrics_markdown.py +++ /dev/null @@ -1,95 +0,0 @@ -"""A Markdown Component for model metrics""" - -from dash import dcc - -# SageWorks Imports -from sageworks.api import Model -from sageworks.web_components.component_interface import ComponentInterface - - -class ModelMetricsMarkdown(ComponentInterface): - """Model Markdown Component""" - - def create_component(self, component_id: str) -> dcc.Markdown: - """Create a Markdown Component without any data. - Args: - component_id (str): The ID of the web component - Returns: - dcc.Markdown: The Dash Markdown Component - """ - waiting_markdown = "*Waiting for data...*" - return dcc.Markdown(id=component_id, children=waiting_markdown, dangerously_allow_html=False) - - def generate_markdown(self, model: Model, inference_run: str) -> str: - """Create the Markdown for the details/information about the DataSource or the FeatureSet - Args: - model (Model): Sageworks Model object - inference_run (str): Valid capture_uuid - Returns: - str: A Markdown string - """ - - # Model Metrics - markdown = "### Model Metrics \n" - meta_df = model.inference_metadata(inference_run) - if meta_df is None: - test_data = "Inference Metadata Not Found" - test_data_hash = " N/A " - test_rows = " - " - description = " - " - else: - inference_meta = meta_df.to_dict(orient="records")[0] - test_data = inference_meta.get("name", " - ") - test_data_hash = inference_meta.get("data_hash", " - ") - test_rows = inference_meta.get("num_rows", " - ") - description = inference_meta.get("description", " - ") - - # Add the markdown for the model test metrics - markdown += f"**Test Data:** {test_data} \n" - markdown += f"**Data Hash:** {test_data_hash} \n" - markdown += f"**Test Rows:** {test_rows} \n" - markdown += f"**Description:** {description} \n" - - # Grab the Metrics from the model details - metrics = model.performance_metrics(capture_uuid=inference_run) - if metrics is None: - markdown += " \nNo Data \n" - else: - markdown += " \n" - metrics = metrics.round(3) - markdown += metrics.to_markdown(index=False) - - return markdown - - -if __name__ == "__main__": - # This class takes in model metrics and generates a Markdown Component - import dash - from dash import dcc, html - import dash_bootstrap_components as dbc - from sageworks.api import Model - - # Create the class and get the AWS FeatureSet details - m = Model("wine-classification") - inference_run = "model_training" - - # Instantiate the DataDetailsMarkdown class - ddm = ModelMetricsMarkdown() - component = ddm.create_component("model_metrics_markdown") - - # Generate the markdown - markdown = ddm.generate_markdown(m, inference_run) - - print(markdown) - - # Show the Markdown in the Web Browser - app = dash.Dash( - __name__, - external_stylesheets=[dbc.themes.DARKLY], - ) - - app.layout = html.Div([component]) - component.children = markdown - - if __name__ == "__main__": - app.run_server(debug=True) From ae045d9ff3ace6319340de7d250c931f42412822 Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Thu, 28 Mar 2024 09:49:57 -0600 Subject: [PATCH 3/6] linter/flake8 fixes --- .../aws_dashboard/pages/models/callbacks.py | 4 +-- .../aws_dashboard/pages/models/layout.py | 1 - .../aws_dashboard/pages/models/page.py | 1 - src/sageworks/web_components/model_details.py | 19 +++++++++---- .../inference_run_selector_test.py | 28 ------------------- 5 files changed, 14 insertions(+), 39 deletions(-) delete mode 100644 tests/web_components/inference_run_selector_test.py diff --git a/applications/aws_dashboard/pages/models/callbacks.py b/applications/aws_dashboard/pages/models/callbacks.py index e8a7cdc63..ce9b7ad6a 100644 --- a/applications/aws_dashboard/pages/models/callbacks.py +++ b/applications/aws_dashboard/pages/models/callbacks.py @@ -7,7 +7,6 @@ from sageworks.views.model_web_view import ModelWebView from sageworks.web_components import ( table, - model_details, model_plot, ) from sageworks.utils.pandas_utils import deserialize_aws_broker_data @@ -58,8 +57,7 @@ def update_model_plot_component(app: Dash): @app.callback( Output("model_plot", "figure"), Input("model_details-dropdown", "value"), - [State("models_table", "data"), - State("models_table", "derived_viewport_selected_row_ids")], + [State("models_table", "data"), State("models_table", "derived_viewport_selected_row_ids")], prevent_initial_call=True, ) def generate_model_plot_figure(inference_run, table_data, selected_rows): diff --git a/applications/aws_dashboard/pages/models/layout.py b/applications/aws_dashboard/pages/models/layout.py index 28316879a..68a4ddd0d 100644 --- a/applications/aws_dashboard/pages/models/layout.py +++ b/applications/aws_dashboard/pages/models/layout.py @@ -34,7 +34,6 @@ def models_layout( [ # Column 1: Model Details dbc.Col(model_details, width=4), - # Column 2: Model Plot and Plugins dbc.Col( [ diff --git a/applications/aws_dashboard/pages/models/page.py b/applications/aws_dashboard/pages/models/page.py index 666c0a1eb..e0654c459 100644 --- a/applications/aws_dashboard/pages/models/page.py +++ b/applications/aws_dashboard/pages/models/page.py @@ -3,7 +3,6 @@ from dash import register_page import dash from dash_bootstrap_templates import load_figure_template -from dash import dcc # Local Imports from .layout import models_layout diff --git a/src/sageworks/web_components/model_details.py b/src/sageworks/web_components/model_details.py index 2675ccc7b..0bae0e712 100644 --- a/src/sageworks/web_components/model_details.py +++ b/src/sageworks/web_components/model_details.py @@ -12,6 +12,7 @@ class ModelDetails(ComponentInterface): """Model Markdown Component""" + def __init__(self): self.prefix_id = "" self.model = None @@ -43,7 +44,7 @@ def register_callbacks(self, model_table): Output(f"{self.prefix_id}-header", "children"), Output(f"{self.prefix_id}-summary", "children"), Output(f"{self.prefix_id}-dropdown", "options"), - Output(f"{self.prefix_id}-dropdown", "value") + Output(f"{self.prefix_id}-dropdown", "value"), ], Input(model_table, "derived_viewport_selected_row_ids"), State(model_table, "data"), @@ -83,15 +84,21 @@ def update_inference_run(inference_run): def model_summary(self): """Construct the markdown string for the model summary - + Returns: str: A markdown string """ # Get these fields from the model # Get these fields from the model - show_fields = ["health_tags", "input", "sageworks_registered_endpoints", - "sageworks_model_type", "sageworks_tags", "sageworks_model_target", - "sageworks_model_features"] + show_fields = [ + "health_tags", + "input", + "sageworks_registered_endpoints", + "sageworks_model_type", + "sageworks_tags", + "sageworks_model_target", + "sageworks_model_features", + ] # Construct the markdown string summary = self.model.summary() @@ -116,7 +123,7 @@ def model_summary(self): # If the value is a list, convert it to a comma-separated string if isinstance(value, list): - value = ', '.join(value) + value = ", ".join(value) # Chop off the "sageworks_" prefix key = key.replace("sageworks_", "") diff --git a/tests/web_components/inference_run_selector_test.py b/tests/web_components/inference_run_selector_test.py deleted file mode 100644 index 5f5f1160e..000000000 --- a/tests/web_components/inference_run_selector_test.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Tests for inference run selector web component""" - -# SageWorks Imports -from sageworks.web_components.inference_run_selector import InferenceRunSelector -from sageworks.api.model import Model - - -def test_inference_dropdown(): - """Test the ConfusionMatrix class""" - # Instantiate model - m = Model("wine-classification") - inference_run = "training_holdout" - - # Instantiate the ConfusionMatrix class - irs = InferenceRunSelector() - - # Create the component - dropdown = irs.create_component("dropdown") - - # TBD - print(m) - print(inference_run) - print(dropdown) - - -if __name__ == "__main__": - # Run the tests - test_inference_dropdown() From 8225dd5727653857c9f47d068183383782f7a81f Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Thu, 28 Mar 2024 10:32:42 -0600 Subject: [PATCH 4/6] adding the endpoint_details web component --- .../pages/endpoints/callbacks.py | 19 +-- .../aws_dashboard/pages/endpoints/page.py | 10 +- .../web_components/endpoint_details.py | 160 ++++++++++++++++++ 3 files changed, 170 insertions(+), 19 deletions(-) create mode 100644 src/sageworks/web_components/endpoint_details.py diff --git a/applications/aws_dashboard/pages/endpoints/callbacks.py b/applications/aws_dashboard/pages/endpoints/callbacks.py index 5b2c7a488..db2780183 100644 --- a/applications/aws_dashboard/pages/endpoints/callbacks.py +++ b/applications/aws_dashboard/pages/endpoints/callbacks.py @@ -5,7 +5,7 @@ # SageWorks Imports from sageworks.views.endpoint_web_view import EndpointWebView -from sageworks.web_components import table, model_details_markdown, endpoint_metric_plots +from sageworks.web_components import table, model_details, endpoint_metric_plots from sageworks.utils.pandas_utils import deserialize_aws_broker_data from sageworks.api.endpoint import Endpoint from sageworks.api.model import Model @@ -54,13 +54,9 @@ def style_selected_rows(selected_rows): # Updates the endpoint details when a endpoint row is selected -def update_endpoint_details_components(app: Dash, endpoint_web_view: EndpointWebView): +def update_endpoint_metrics(app: Dash, endpoint_web_view: EndpointWebView): @app.callback( - [ - Output("endpoint_details_header", "children"), - Output("endpoint_details", "children"), - Output("endpoint_metrics", "figure"), - ], + Output("endpoint_metrics", "figure"), Input("endpoints_table", "derived_viewport_selected_row_ids"), State("endpoints_table", "data"), prevent_initial_call=True, @@ -75,21 +71,14 @@ def generate_endpoint_details_figures(selected_rows, table_data): endpoint_uuid = selected_row_data["uuid"] print(f"Endpoint UUID: {endpoint_uuid}") - # Set the Header Text - header = f"Details: {endpoint_uuid}" - # Endpoint Details endpoint_details = endpoint_web_view.endpoint_details(endpoint_uuid) - # Model Details Markdown component (Review This) - model = Model(Endpoint(endpoint_uuid).model_name) - endpoint_details_markdown = model_details_markdown.ModelDetailsMarkdown().generate_markdown(model) - # Endpoint Metrics endpoint_metrics_figure = endpoint_metric_plots.EndpointMetricPlots().generate_figure(endpoint_details) # Return the details/markdown for these data details - return [header, endpoint_details_markdown, endpoint_metrics_figure] + return endpoint_metrics_figure # Updates the plugin component when a endpoint row is selected diff --git a/applications/aws_dashboard/pages/endpoints/page.py b/applications/aws_dashboard/pages/endpoints/page.py index 93e2f2918..38db824ae 100644 --- a/applications/aws_dashboard/pages/endpoints/page.py +++ b/applications/aws_dashboard/pages/endpoints/page.py @@ -9,7 +9,7 @@ from . import callbacks # SageWorks Imports -from sageworks.web_components import table, model_details_markdown, endpoint_metric_plots +from sageworks.web_components import table, endpoint_details, endpoint_metric_plots from sageworks.web_components.plugin_interface import PluginPage from sageworks.views.endpoint_web_view import EndpointWebView from sageworks.utils.plugin_manager import PluginManager @@ -33,7 +33,8 @@ ) # Create a Markdown component to display the endpoint details -endpoint_details = model_details_markdown.ModelDetailsMarkdown().create_component("endpoint_details") +endpoint_details = endpoint_details.EndpointDetails() +endpoint_details_component = endpoint_details.create_component("endpoint_details") # Create a component to display the endpoint metrics endpoint_metrics = endpoint_metric_plots.EndpointMetricPlots().create_component("endpoint_metrics") @@ -41,7 +42,7 @@ # Capture our components in a dictionary to send off to the layout components = { "endpoints_table": endpoints_table, - "endpoint_details": endpoint_details, + "endpoint_details": endpoint_details_component, "endpoint_metrics": endpoint_metrics, } @@ -60,10 +61,11 @@ # Setup our callbacks/connections app = dash.get_app() callbacks.update_endpoints_table(app) +endpoint_details.register_callbacks("endpoints_table") # Callback for the endpoints table callbacks.table_row_select(app, "endpoints_table") -callbacks.update_endpoint_details_components(app, endpoint_broker) +callbacks.update_endpoint_metrics(app, endpoint_broker) # For each plugin, set up a callback to update the plugin figure for plugin in plugins: diff --git a/src/sageworks/web_components/endpoint_details.py b/src/sageworks/web_components/endpoint_details.py new file mode 100644 index 000000000..fd08dfa50 --- /dev/null +++ b/src/sageworks/web_components/endpoint_details.py @@ -0,0 +1,160 @@ +"""A Markdown Component for details/information about Endpoints""" + +# Dash Imports +from dash import html, callback, no_update, dcc +from dash.dependencies import Input, Output, State + +# SageWorks Imports +from sageworks.api import Endpoint +from sageworks.web_components.component_interface import ComponentInterface +from sageworks.utils.symbols import health_icons + + +class EndpointDetails(ComponentInterface): + """Model Markdown Component""" + + def __init__(self): + self.prefix_id = "" + self.endpoint = None + super().__init__() + + def create_component(self, component_id: str) -> html.Div: + """Create a Markdown Component without any data. + Args: + component_id (str): The ID of the web component + Returns: + html.Div: A Container of Components for the Model Details + """ + self.prefix_id = component_id + container = html.Div( + id=self.prefix_id, + children=[ + html.H3(id=f"{self.prefix_id}-header", children="Endpoint: Loading..."), + dcc.Markdown(id=f"{self.prefix_id}-details"), + ], + ) + return container + + def register_callbacks(self, endpoint_table): + @callback( + [ + Output(f"{self.prefix_id}-header", "children"), + Output(f"{self.prefix_id}-details", "children"), + ], + Input(endpoint_table, "derived_viewport_selected_row_ids"), + State(endpoint_table, "data"), + ) + def update_endpoint(selected_rows, table_data): + # Check for no selected rows + if not selected_rows or selected_rows[0] is None: + return no_update + + # Get the selected row data, grab the uuid, and set the Model object + selected_row_data = table_data[selected_rows[0]] + endpoint_uuid = selected_row_data["uuid"] + self.endpoint = Endpoint(endpoint_uuid) + + # Update the header, the summary, and the details + header = f"Model: {self.endpoint.uuid}" + details = self.endpoint_details() + + return header, details + + def endpoint_details(self): + """Construct the markdown string for the endpoint details + + Returns: + str: A markdown string + """ + # Get these fields from the endpoint + show_fields = ["health_tags", "input", "status", "instance", "variant"] + + # Construct the markdown string + summary = self.endpoint.details() + markdown = "" + for key in show_fields: + + # Special case for the health tags + if key == "health_tags": + markdown += self._health_tag_markdown(summary.get(key, [])) + continue + + # Get the value + value = summary.get(key, "-") + + # If the value is a list, convert it to a comma-separated string + if isinstance(value, list): + value = ", ".join(value) + + # Chop off the "sageworks_" prefix + key = key.replace("sageworks_", "") + + # Add to markdown string + markdown += f"**{key}:** {value} \n" + + return markdown + + @staticmethod + def _health_tag_markdown(health_tags: list[str]) -> str: + """Internal method to generate the health tag markdown + Args: + health_tags (list[str]): A list of health tags + Returns: + str: A markdown string + """ + # If we have no health tags, then add a bullet for healthy + markdown = "**Health Checks**\n" # Header for Health Checks + + # If we have no health tags, then add a bullet for healthy + if not health_tags: + markdown += f"* Healthy: {health_icons.get('healthy')}\n\n" + return markdown + + # Special case for no_activity with no other tags + if len(health_tags) == 1 and health_tags[0] == "no_activity": + markdown += f"* Healthy: {health_icons.get('healthy')}\n" + markdown += f"* No Activity: {health_icons.get('no_activity')}\n\n" + return markdown + + # If we have health tags, then add a bullet for each tag + markdown += "\n".join(f"* {tag}: {health_icons.get(tag, '')}" for tag in health_tags) + markdown += "\n\n" # Add newlines for separation + return markdown + + +if __name__ == "__main__": + # This class takes in endpoint details and generates a details Markdown component + import dash + import dash_bootstrap_components as dbc + from sageworks.web_components.table import Table + from sageworks.views.artifacts_web_view import ArtifactsWebView + + # Create a endpoint table + endpoints_table = Table().create_component( + "endpoints_table", header_color="rgb(60, 100, 60)", row_select="single", max_height=270 + ) + + # Populate the table with data + view = ArtifactsWebView() + endpoints = view.endpoints_summary() + endpoints["id"] = range(len(endpoints)) + column_setup_list = Table().column_setup(endpoints, markdown_columns=["Name"]) + endpoints_table.columns = column_setup_list + endpoints_table.data = endpoints.to_dict("records") + + # Instantiate the EndpointDetails class + md = EndpointDetails() + details_component = md.create_component("endpoint_details") + + # Register the callbacks + md.register_callbacks("endpoints_table") + + # Initialize Dash app + app = dash.Dash( + __name__, + external_stylesheets=[dbc.themes.DARKLY], + assets_folder="/Users/briford/work/sageworks/applications/aws_dashboard/assets", + ) + + app.layout = html.Div([endpoints_table, details_component]) + app.run_server(debug=True) From 4f6e4a580a381cb62eb8086c12da86aad1afa84b Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Thu, 28 Mar 2024 10:37:00 -0600 Subject: [PATCH 5/6] flake8 fixes --- applications/aws_dashboard/pages/endpoints/callbacks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/applications/aws_dashboard/pages/endpoints/callbacks.py b/applications/aws_dashboard/pages/endpoints/callbacks.py index db2780183..78d9708f9 100644 --- a/applications/aws_dashboard/pages/endpoints/callbacks.py +++ b/applications/aws_dashboard/pages/endpoints/callbacks.py @@ -5,10 +5,9 @@ # SageWorks Imports from sageworks.views.endpoint_web_view import EndpointWebView -from sageworks.web_components import table, model_details, endpoint_metric_plots +from sageworks.web_components import table, endpoint_metric_plots from sageworks.utils.pandas_utils import deserialize_aws_broker_data from sageworks.api.endpoint import Endpoint -from sageworks.api.model import Model def update_endpoints_table(app: Dash): From b08a21b9db8fe5c5f7ada853952044867d33a13a Mon Sep 17 00:00:00 2001 From: Brian Wylie Date: Thu, 28 Mar 2024 10:53:29 -0600 Subject: [PATCH 6/6] adding the legacy=True flag to read in legacy models --- applications/aws_dashboard/pages/endpoints/callbacks.py | 2 +- applications/aws_dashboard/pages/models/callbacks.py | 4 ++-- src/sageworks/web_components/endpoint_details.py | 2 +- src/sageworks/web_components/model_details.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/applications/aws_dashboard/pages/endpoints/callbacks.py b/applications/aws_dashboard/pages/endpoints/callbacks.py index 78d9708f9..c69754ca8 100644 --- a/applications/aws_dashboard/pages/endpoints/callbacks.py +++ b/applications/aws_dashboard/pages/endpoints/callbacks.py @@ -98,7 +98,7 @@ def update_callback(selected_rows, table_data): endpoint_uuid = selected_row_data["uuid"] # Instantiate the Endpoint and send it to the plugin - endpoint = Endpoint(endpoint_uuid) + endpoint = Endpoint(endpoint_uuid, legacy=True) # Instantiate the Endpoint and send it to the plugin return plugin.generate_figure(endpoint) diff --git a/applications/aws_dashboard/pages/models/callbacks.py b/applications/aws_dashboard/pages/models/callbacks.py index ce9b7ad6a..ac9c160a5 100644 --- a/applications/aws_dashboard/pages/models/callbacks.py +++ b/applications/aws_dashboard/pages/models/callbacks.py @@ -68,7 +68,7 @@ def generate_model_plot_figure(inference_run, table_data, selected_rows): # Get the selected row data and grab the uuid selected_row_data = table_data[selected_rows[0]] model_uuid = selected_row_data["uuid"] - m = Model(model_uuid) + m = Model(model_uuid, legacy=True) # Model Details Markdown component model_plot_fig = model_plot.ModelPlot().generate_figure(m, inference_run) @@ -95,5 +95,5 @@ def update_plugin_figure(selected_rows, table_data): model_uuid = selected_row_data["uuid"] # Instantiate the Model and send it to the plugin - model = Model(model_uuid) + model = Model(model_uuid, legacy=True) return plugin.generate_figure(model) diff --git a/src/sageworks/web_components/endpoint_details.py b/src/sageworks/web_components/endpoint_details.py index fd08dfa50..179ad49b0 100644 --- a/src/sageworks/web_components/endpoint_details.py +++ b/src/sageworks/web_components/endpoint_details.py @@ -52,7 +52,7 @@ def update_endpoint(selected_rows, table_data): # Get the selected row data, grab the uuid, and set the Model object selected_row_data = table_data[selected_rows[0]] endpoint_uuid = selected_row_data["uuid"] - self.endpoint = Endpoint(endpoint_uuid) + self.endpoint = Endpoint(endpoint_uuid, legacy=True) # Update the header, the summary, and the details header = f"Model: {self.endpoint.uuid}" diff --git a/src/sageworks/web_components/model_details.py b/src/sageworks/web_components/model_details.py index 0bae0e712..8d7dd6fe0 100644 --- a/src/sageworks/web_components/model_details.py +++ b/src/sageworks/web_components/model_details.py @@ -57,7 +57,7 @@ def update_model(selected_rows, table_data): # Get the selected row data, grab the uuid, and set the Model object selected_row_data = table_data[selected_rows[0]] model_uuid = selected_row_data["uuid"] - self.model = Model(model_uuid) + self.model = Model(model_uuid, legacy=True) # Update the header, the summary, and the details header = f"Model: {self.model.uuid}"