Skip to content

Commit

Permalink
setting up themes for scatter plot
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Dec 6, 2024
1 parent 5cedd4b commit 317dac6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,4 @@ def highlight_cm_square(click_data, current_figure):

# Run the Unit Test on the Plugin
model = CachedModel("wine-classification")
PluginUnitTest(ConfusionMatrix, input_data=model, theme="light").run()
PluginUnitTest(ConfusionMatrix, input_data=model, theme="dark").run()
28 changes: 8 additions & 20 deletions src/sageworks/web_interface/components/plugins/scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# SageWorks Imports
from sageworks.api import DataSource, FeatureSet
from sageworks.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
from sageworks.utils.theme_manager import ThemeManager


class ScatterPlot(PluginInterface):
Expand All @@ -22,6 +23,9 @@ def __init__(self):
self.hover_columns = []
self.df = None

# Initialize the Theme Manager
self.theme_manager = ThemeManager()

# Call the parent class constructor
super().__init__()

Expand Down Expand Up @@ -178,14 +182,6 @@ def create_scatter_plot(self, df, x_col, y_col, color_col, regression_line=False
go.Figure: A Plotly Figure object.
"""

# Define a custom color scale (blue -> yellow -> orange -> red)
color_scale = [
[0.0, "rgb(64, 64, 160)"],
[0.33, "rgb(48, 140, 140)"],
[0.67, "rgb(140, 140, 48)"],
[1.0, "rgb(160, 64, 64)"],
]

# Create an OpenGL Scatter Plot
figure = go.Figure(
data=go.Scattergl(
Expand All @@ -200,7 +196,7 @@ def create_scatter_plot(self, df, x_col, y_col, color_col, regression_line=False
marker=dict(
size=15,
color=df[color_col], # Use the selected field for color
colorscale=color_scale,
colorscale=self.theme_manager.colorscale(),
colorbar=dict(title=color_col),
opacity=df[color_col].apply(
lambda x: 0.25 + 0.74 * (x - df[color_col].min()) / (df[color_col].max() - df[color_col].min())
Expand All @@ -216,36 +212,28 @@ def create_scatter_plot(self, df, x_col, y_col, color_col, regression_line=False
max_val = max(df[x_col].max(), df[y_col].max())
figure.add_shape(
type="line",
line=dict(width=4, color="rgba(1.0, 1.0, 1.0, 0.25)"),
line=dict(width=4, color="rgba(0.5, 0.5, 0.5, 0.5)"),
x0=min_val,
x1=max_val,
y0=min_val,
y1=max_val,
)

# Apply the selected theme and set transparent background
plotly_theme = "plotly_dark" if self.dark_theme else "plotly"
# Update the layout
figure.update_layout(
template=plotly_theme,
margin={"t": 40, "b": 40, "r": 40, "l": 40, "pad": 0},
xaxis=dict(
title=x_col,
tickformat=".2f",
showgrid=True,
gridcolor="rgba(100,100,100,0.25)", # Medium grey
zerolinecolor="rgba(80, 80, 150, 0.5)", # Blue color for the zero line
),
yaxis=dict(
title=y_col,
tickformat=".2f",
showgrid=True,
gridcolor="rgba(100,100,100,0.25)", # Medium grey
zerolinecolor="rgba(80, 80, 150, 0.5)", # Blue color for the zero line
),
showlegend=False, # Remove legend
dragmode="pan",
paper_bgcolor="rgba(0,0,0,0)", # Set the paper background to transparent
plot_bgcolor="rgba(0,0,0,0)", # Set the plot background to transparent
)
return figure

Expand Down Expand Up @@ -279,4 +267,4 @@ def update_graph(x_value, y_value, color_value, regression_line):
from sageworks.web_interface.components.plugin_unit_test import PluginUnitTest

# Run the Unit Test on the Plugin
PluginUnitTest(ScatterPlot).run()
PluginUnitTest(ScatterPlot, theme="light").run()

0 comments on commit 317dac6

Please sign in to comment.