Skip to content

Commit

Permalink
Created a slopegraph_plotter function in plot_tools and trimmed plotter
Browse files Browse the repository at this point in the history
  • Loading branch information
JAnns98 committed Sep 9, 2024
1 parent 3f65777 commit 667909b
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 76 deletions.
2 changes: 2 additions & 0 deletions dabest/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@
'dabest.plot_tools.normalize_dict': ('API/plot_tools.html#normalize_dict', 'dabest/plot_tools.py'),
'dabest.plot_tools.sankeydiag': ('API/plot_tools.html#sankeydiag', 'dabest/plot_tools.py'),
'dabest.plot_tools.single_sankey': ('API/plot_tools.html#single_sankey', 'dabest/plot_tools.py'),
'dabest.plot_tools.slopegraph_plotter': ( 'API/plot_tools.html#slopegraph_plotter',
'dabest/plot_tools.py'),
'dabest.plot_tools.summary_bars_plotter': ( 'API/plot_tools.html#summary_bars_plotter',
'dabest/plot_tools.py'),
'dabest.plot_tools.swarm_bars_plotter': ( 'API/plot_tools.html#swarm_bars_plotter',
Expand Down
43 changes: 42 additions & 1 deletion dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# %% auto 0
__all__ = ['halfviolin', 'get_swarm_spans', 'error_bar', 'check_data_matches_labels', 'normalize_dict', 'width_determine',
'single_sankey', 'sankeydiag', 'summary_bars_plotter', 'contrast_bars_plotter', 'swarm_bars_plotter',
'delta_text_plotter', 'DeltaDotsPlotter', 'swarmplot', 'SwarmPlot']
'delta_text_plotter', 'DeltaDotsPlotter', 'slopegraph_plotter', 'swarmplot', 'SwarmPlot']

# %% ../nbs/API/plot_tools.ipynb 4
import math
Expand Down Expand Up @@ -1108,6 +1108,47 @@ def DeltaDotsPlotter(plot_data, contrast_axes, delta_id_col, idx, xvar, yvar, is
**delta_dot_kwargs)
contrast_axes.legend().set_visible(False)


def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palette_raw, slopegraph_kwargs, rawdata_axes, ytick_color, temp_idx):

# Pivot the long (melted) data.
if color_col is None:
pivot_values = [yvar]
else:
pivot_values = [yvar, color_col]
pivoted_plot_data = pd.pivot(
data=plot_data,
index=dabest_obj.id_col,
columns=xvar,
values=pivot_values,
)

x_start = 0
for ii, current_tuple in enumerate(temp_idx):
current_pair = pivoted_plot_data.loc[
:, pd.MultiIndex.from_product([pivot_values, current_tuple])
].dropna()
grp_count = len(current_tuple)
# Iterate through the data for the current tuple.
for ID, observation in current_pair.iterrows():
x_points = [t for t in range(x_start, x_start + grp_count)]
y_points = observation[yvar].tolist()

if color_col is None:
slopegraph_kwargs["color"] = ytick_color
else:
color_key = observation[color_col][0]
if isinstance(color_key, (str, np.int64, np.float64)):
slopegraph_kwargs["color"] = plot_palette_raw[color_key]
slopegraph_kwargs["label"] = color_key

rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)

x_start = x_start + grp_count


...

# %% ../nbs/API/plot_tools.ipynb 6
def swarmplot(
data: pd.DataFrame,
Expand Down
44 changes: 7 additions & 37 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
summary_bars_plotter,
delta_text_plotter,
DeltaDotsPlotter,
slopegraph_plotter,
)
from ._stats_tools.effsize import (
_compute_standardizers,
Expand Down Expand Up @@ -135,19 +136,19 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):

################################################### END GRIDKEY WIP - extracting arguments

################################################### Color palette WIP Start
# Extract Color palette

(color_col, bootstraps_color_by_group, n_groups, swarm_colors, plot_palette_raw, bar_color, plot_palette_bar,
plot_palette_contrast, plot_palette_sankey) = get_color_palette(plot_kwargs=plot_kwargs, plot_data=plot_data,
xvar=xvar, show_pairs=show_pairs)
################################################### Color palette WIP End

# Initialise the figure.
fig, rawdata_axes, contrast_axes, swarm_ylim = initialize_fig(plot_kwargs=plot_kwargs, dabest_obj=dabest_obj, show_delta2=show_delta2,
show_mini_meta=show_mini_meta, is_paired=is_paired, show_pairs=show_pairs,
proportional=proportional, float_contrast=float_contrast, face_color=face_color,
h_space_cummings=h_space_cummings)


# Plotting.
if show_pairs:
# Determine temp_idx based on is_paired and proportional conditions
if is_paired == "baseline":
Expand All @@ -169,40 +170,9 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):

if not proportional:
# Plot the raw data as a slopegraph.
# Pivot the long (melted) data.
if color_col is None:
pivot_values = [yvar]
else:
pivot_values = [yvar, color_col]
pivoted_plot_data = pd.pivot(
data=plot_data,
index=dabest_obj.id_col,
columns=xvar,
values=pivot_values,
)
x_start = 0
for ii, current_tuple in enumerate(temp_idx):
current_pair = pivoted_plot_data.loc[
:, pd.MultiIndex.from_product([pivot_values, current_tuple])
].dropna()
grp_count = len(current_tuple)
# Iterate through the data for the current tuple.
for ID, observation in current_pair.iterrows():
x_points = [t for t in range(x_start, x_start + grp_count)]
y_points = observation[yvar].tolist()

if color_col is None:
slopegraph_kwargs["color"] = ytick_color
else:
color_key = observation[color_col][0]
if isinstance(color_key, (str, np.int64, np.float64)):
slopegraph_kwargs["color"] = plot_palette_raw[color_key]
slopegraph_kwargs["label"] = color_key

rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)

x_start = x_start + grp_count

slopegraph_plotter(dabest_obj=dabest_obj, plot_data=plot_data, xvar=xvar, yvar=yvar, color_col=color_col,
plot_palette_raw=plot_palette_raw, slopegraph_kwargs=slopegraph_kwargs, rawdata_axes=rawdata_axes,
ytick_color=ytick_color, temp_idx=temp_idx)

##################### DELTA PTS ON CONTRAST PLOT WIP
show_delta_dots = plot_kwargs["delta_dot"]
Expand Down
43 changes: 42 additions & 1 deletion nbs/API/plot_tools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,48 @@
" is_drop_gutter=True,\n",
" gutter_limit=1,\n",
" **delta_dot_kwargs)\n",
" contrast_axes.legend().set_visible(False)"
" contrast_axes.legend().set_visible(False)\n",
"\n",
"\n",
"def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palette_raw, slopegraph_kwargs, rawdata_axes, ytick_color, temp_idx):\n",
" \n",
" # Pivot the long (melted) data.\n",
" if color_col is None:\n",
" pivot_values = [yvar]\n",
" else:\n",
" pivot_values = [yvar, color_col]\n",
" pivoted_plot_data = pd.pivot(\n",
" data=plot_data,\n",
" index=dabest_obj.id_col,\n",
" columns=xvar,\n",
" values=pivot_values,\n",
" )\n",
"\n",
" x_start = 0\n",
" for ii, current_tuple in enumerate(temp_idx):\n",
" current_pair = pivoted_plot_data.loc[\n",
" :, pd.MultiIndex.from_product([pivot_values, current_tuple])\n",
" ].dropna()\n",
" grp_count = len(current_tuple)\n",
" # Iterate through the data for the current tuple.\n",
" for ID, observation in current_pair.iterrows():\n",
" x_points = [t for t in range(x_start, x_start + grp_count)]\n",
" y_points = observation[yvar].tolist()\n",
"\n",
" if color_col is None:\n",
" slopegraph_kwargs[\"color\"] = ytick_color\n",
" else:\n",
" color_key = observation[color_col][0]\n",
" if isinstance(color_key, (str, np.int64, np.float64)):\n",
" slopegraph_kwargs[\"color\"] = plot_palette_raw[color_key]\n",
" slopegraph_kwargs[\"label\"] = color_key\n",
"\n",
" rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)\n",
"\n",
" x_start = x_start + grp_count\n",
"\n",
"\n",
" ..."
]
},
{
Expand Down
44 changes: 7 additions & 37 deletions nbs/API/plotter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@
" summary_bars_plotter,\n",
" delta_text_plotter,\n",
" DeltaDotsPlotter,\n",
" slopegraph_plotter,\n",
" )\n",
" from ._stats_tools.effsize import (\n",
" _compute_standardizers,\n",
Expand Down Expand Up @@ -194,19 +195,19 @@
"\n",
" ################################################### END GRIDKEY WIP - extracting arguments\n",
"\n",
" ################################################### Color palette WIP Start\n",
" # Extract Color palette\n",
"\n",
" (color_col, bootstraps_color_by_group, n_groups, swarm_colors, plot_palette_raw, bar_color, plot_palette_bar, \n",
" plot_palette_contrast, plot_palette_sankey) = get_color_palette(plot_kwargs=plot_kwargs, plot_data=plot_data, \n",
" xvar=xvar, show_pairs=show_pairs)\n",
" ################################################### Color palette WIP End\n",
"\n",
" # Initialise the figure.\n",
" fig, rawdata_axes, contrast_axes, swarm_ylim = initialize_fig(plot_kwargs=plot_kwargs, dabest_obj=dabest_obj, show_delta2=show_delta2, \n",
" show_mini_meta=show_mini_meta, is_paired=is_paired, show_pairs=show_pairs, \n",
" proportional=proportional, float_contrast=float_contrast, face_color=face_color, \n",
" h_space_cummings=h_space_cummings)\n",
"\n",
" \n",
" # Plotting.\n",
" if show_pairs:\n",
" # Determine temp_idx based on is_paired and proportional conditions\n",
" if is_paired == \"baseline\":\n",
Expand All @@ -228,40 +229,9 @@
"\n",
" if not proportional:\n",
" # Plot the raw data as a slopegraph.\n",
" # Pivot the long (melted) data.\n",
" if color_col is None:\n",
" pivot_values = [yvar]\n",
" else:\n",
" pivot_values = [yvar, color_col]\n",
" pivoted_plot_data = pd.pivot(\n",
" data=plot_data,\n",
" index=dabest_obj.id_col,\n",
" columns=xvar,\n",
" values=pivot_values,\n",
" )\n",
" x_start = 0\n",
" for ii, current_tuple in enumerate(temp_idx):\n",
" current_pair = pivoted_plot_data.loc[\n",
" :, pd.MultiIndex.from_product([pivot_values, current_tuple])\n",
" ].dropna()\n",
" grp_count = len(current_tuple)\n",
" # Iterate through the data for the current tuple.\n",
" for ID, observation in current_pair.iterrows():\n",
" x_points = [t for t in range(x_start, x_start + grp_count)]\n",
" y_points = observation[yvar].tolist()\n",
"\n",
" if color_col is None:\n",
" slopegraph_kwargs[\"color\"] = ytick_color\n",
" else:\n",
" color_key = observation[color_col][0]\n",
" if isinstance(color_key, (str, np.int64, np.float64)):\n",
" slopegraph_kwargs[\"color\"] = plot_palette_raw[color_key]\n",
" slopegraph_kwargs[\"label\"] = color_key\n",
"\n",
" rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)\n",
"\n",
" x_start = x_start + grp_count\n",
"\n",
" slopegraph_plotter(dabest_obj=dabest_obj, plot_data=plot_data, xvar=xvar, yvar=yvar, color_col=color_col, \n",
" plot_palette_raw=plot_palette_raw, slopegraph_kwargs=slopegraph_kwargs, rawdata_axes=rawdata_axes, \n",
" ytick_color=ytick_color, temp_idx=temp_idx)\n",
"\n",
" ##################### DELTA PTS ON CONTRAST PLOT WIP\n",
" show_delta_dots = plot_kwargs[\"delta_dot\"]\n",
Expand Down

0 comments on commit 667909b

Please sign in to comment.