diff --git a/src/sageworks/core/cloud_platform/aws/aws_meta.py b/src/sageworks/core/cloud_platform/aws/aws_meta.py index 53776a994..cea7199ca 100644 --- a/src/sageworks/core/cloud_platform/aws/aws_meta.py +++ b/src/sageworks/core/cloud_platform/aws/aws_meta.py @@ -9,7 +9,6 @@ import awswrangler as wr from collections import defaultdict - # SageWorks Imports from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp from sageworks.utils.config_manager import ConfigManager @@ -127,7 +126,15 @@ def data_sources(self) -> pd.DataFrame: Returns: pd.DataFrame: A summary of the Data Sources deployed in the AWS Platform """ - return self._list_catalog_tables("sageworks") + data_sources_df = self._list_catalog_tables("sageworks") + data_sources_df["Health"] = "" + + # Make the Health column the second column + cols = data_sources_df.columns.tolist() + cols.remove("Health") + cols.insert(1, "Health") + data_sources_df = data_sources_df[cols] + return data_sources_df def views(self, database: str = "sageworks") -> pd.DataFrame: """Get a summary of the all the Views, for the given database, in AWS @@ -169,6 +176,7 @@ def feature_sets(self, details: bool = False) -> pd.DataFrame: aws_tags = self.get_aws_tags(fg["FeatureGroupArn"]) summary = { "Feature Group": name, + "Health": "", "Owner": aws_tags.get("sageworks_owner", "-"), "Created": datetime_string(feature_set_details.get("CreationTime")), "Num Columns": len(feature_set_details.get("FeatureDefinitions", [])), diff --git a/src/sageworks/utils/theme_manager.py b/src/sageworks/utils/theme_manager.py index f0750e503..99c4b3498 100644 --- a/src/sageworks/utils/theme_manager.py +++ b/src/sageworks/utils/theme_manager.py @@ -105,6 +105,28 @@ def colorscale(cls, scale_type: str = "sequential") -> list[list[float | str]]: cls._log.error(f"No color scales found for template '{cls._current_theme_name}'.") return [] + @staticmethod + def adjust_colorscale_alpha(colorscale, alpha=0.5): + """ + Adjust the alpha value of the first color in the colorscale. + + Args: + colorscale (list): The colorscale list with format [[value, color], ...]. + alpha (float): The new alpha value for the first color (0 to 1). + + Returns: + list: The updated colorscale. + """ + updated_colorscale = colorscale.copy() + + if updated_colorscale and "rgba" in updated_colorscale[0][1]: + # Parse the existing RGBA value and modify alpha + rgba_values = updated_colorscale[0][1].strip("rgba()").split(",") + rgba_values[-1] = str(alpha) # Update the alpha channel + updated_colorscale[0][1] = f"rgba({','.join(rgba_values)})" + + return updated_colorscale + @classmethod def css_files(cls) -> list[str]: """Get the list of CSS files for the current theme.""" diff --git a/src/sageworks/web_interface/components/plugins/confusion_matrix.py b/src/sageworks/web_interface/components/plugins/confusion_matrix.py index 86e99daf4..4ad039791 100644 --- a/src/sageworks/web_interface/components/plugins/confusion_matrix.py +++ b/src/sageworks/web_interface/components/plugins/confusion_matrix.py @@ -78,6 +78,8 @@ def update_properties(self, model: CachedModel, **kwargs) -> list: y_labels = [f"{c}:{i}" for i, c in enumerate(df.index)] # Create the heatmap figure + colorscale = self.theme_manager.colorscale() + colorscale = self.theme_manager.adjust_colorscale_alpha(colorscale, alpha=0.25) fig = go.Figure( data=go.Heatmap( z=df, @@ -85,7 +87,7 @@ def update_properties(self, model: CachedModel, **kwargs) -> list: y=y_labels, xgap=3, # Add space between cells ygap=3, - colorscale=self.theme_manager.colorscale(), # Use the current theme's colorscale + colorscale=colorscale, # Use the current theme's colorscale ) ) diff --git a/src/sageworks/web_interface/page_views/data_sources_page_view.py b/src/sageworks/web_interface/page_views/data_sources_page_view.py index dceea60de..43add4fe9 100644 --- a/src/sageworks/web_interface/page_views/data_sources_page_view.py +++ b/src/sageworks/web_interface/page_views/data_sources_page_view.py @@ -6,6 +6,7 @@ from sageworks.web_interface.page_views.page_view import PageView from sageworks.cached.cached_meta import CachedMeta from sageworks.cached.cached_data_source import CachedDataSource +from sageworks.utils.symbols import tag_symbols class DataSourcesPageView(PageView): @@ -26,6 +27,10 @@ def refresh(self): self.log.important("Calling refresh()..") self.data_sources_df = self.meta.data_sources() + # Add Health Symbols to the Model Group Name + if "Health" in self.data_sources_df.columns: + self.data_sources_df["Health"] = self.data_sources_df["Health"].map(lambda x: tag_symbols(x)) + def data_sources(self) -> pd.DataFrame: """Get a list of all the DataSources