Skip to content

Commit

Permalink
Added barplot functionality to horizontal plots
Browse files Browse the repository at this point in the history
  • Loading branch information
JAnns98 committed Sep 27, 2024
1 parent f7bf14f commit 2f72fc0
Show file tree
Hide file tree
Showing 9 changed files with 502 additions and 187 deletions.
4 changes: 4 additions & 0 deletions dabest/_effsize_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,7 @@ def plot(

# Horizontal Plots
horizontal=False,
horizontal_table_kwargs=None,
):
"""
Creates an estimation plot for the effect size of interest.
Expand Down Expand Up @@ -1212,6 +1213,9 @@ def plot(
horizontal : boolean, default False
Whether or not to plot the effect size plot in a horizontal format.
horizontal_table_kwargs : dict, default None
Pass relevant keyword arguments to the horizontal table. If None, the following keywords are passed:
{'color' : 'yellow', 'alpha' :0.2, 'fontsize' : 12, 'text_color' : 'black', 'text_units' : None, 'paired_gap_dashes' : False, 'fontsize_label': 12}
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions dabest/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@
'dabest.plot_tools.swarm_bars_plotter': ( 'API/plot_tools.html#swarm_bars_plotter',
'dabest/plot_tools.py'),
'dabest.plot_tools.swarmplot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'),
'dabest.plot_tools.table_for_horizontal_plots': ( 'API/plot_tools.html#table_for_horizontal_plots',
'dabest/plot_tools.py'),
'dabest.plot_tools.width_determine': ('API/plot_tools.html#width_determine', 'dabest/plot_tools.py')},
'dabest.plotter': { 'dabest.plotter.effectsize_df_plotter': ('API/plotter.html#effectsize_df_plotter', 'dabest/plotter.py'),
'dabest.plotter.effectsize_df_plotter_horizontal': ( 'API/plotter.html#effectsize_df_plotter_horizontal',
Expand Down
55 changes: 32 additions & 23 deletions dabest/misc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,17 @@ def get_kwargs(plot_kwargs, ytick_color):
else:
contrast_bars_kwargs = merge_two_dicts(default_contrast_bars_kwargs, plot_kwargs["contrast_bars_kwargs"])

# Table axes for horizontal plot kwargs.
default_table_kwargs = {'color' : 'yellow','alpha' :0.2,'fontsize' : 12,'text_color' : 'black', 'text_units' : None,'paired_gap_dashes' : False,
'fontsize_label': 12}
if plot_kwargs["horizontal_table_kwargs"] is None:
table_kwargs = default_table_kwargs
else:
table_kwargs = merge_two_dicts(default_table_kwargs, plot_kwargs["horizontal_table_kwargs"])

return (swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs, slopegraph_kwargs,
reflines_kwargs, legend_kwargs, group_summary_kwargs, redraw_axes_kwargs, delta_dot_kwargs,
delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs)
delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs, table_kwargs)


def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_groups):
Expand Down Expand Up @@ -987,7 +995,6 @@ def Cumming_Plot_Aesthetic_Adjustments(plot_kwargs, show_delta2, effect_size_typ

# If 0 lies within the ylim of the contrast axes,
# draw a zero reference line.

if horizontal:
contrast_axes_xlim = contrast_axes.get_xlim()
if contrast_axes_xlim[0] < contrast_axes_xlim[1]:
Expand Down Expand Up @@ -1029,7 +1036,7 @@ def Cumming_Plot_Aesthetic_Adjustments(plot_kwargs, show_delta2, effect_size_typ
end_tick = rightend_ticks_raw[k]
ax.vlines(ymin=start_tick, ymax=end_tick, **redraw_axes_kwargs)
ax.set_xlim(xlim)
del redraw_axes_kwargs["x"]
del redraw_axes_kwargs["x"]
else:
# Compute the end of each x-axes line.
if two_col_sankey:
Expand All @@ -1040,7 +1047,6 @@ def Cumming_Plot_Aesthetic_Adjustments(plot_kwargs, show_delta2, effect_size_typ
rightend_ticks = np.array([len(i) - 1 for i in idx]) + np.array(
ticks_to_skip
)

for ax in [rawdata_axes]:
sns.despine(ax=ax, left=True)

Expand All @@ -1058,6 +1064,7 @@ def Cumming_Plot_Aesthetic_Adjustments(plot_kwargs, show_delta2, effect_size_typ
ax.vlines(ymin=start_tick, ymax=end_tick, **redraw_axes_kwargs)

ax.set_xlim(xlim)
ax.set_ylim(ylim)
del redraw_axes_kwargs["x"]

# Remove y ticks and labels from the contrast axes.
Expand Down Expand Up @@ -1240,21 +1247,21 @@ def General_Plot_Aesthetic_Adjustments(show_delta2, show_mini_meta, contrast_axe

# Because we turned the axes frame off, we also need to draw back
# the x-spine for both axes.
rawdata_axes.set_ylim(contrast_axes.get_ylim())
og_xlim_raw = og_ylim_raw
new_ylim_raw = rawdata_axes.get_ylim()

rawdata_axes.hlines(
new_ylim_raw[1], og_xlim_raw[0], og_xlim_raw[1], **redraw_axes_kwargs
)
# rawdata_axes.set_ylim(contrast_axes.get_ylim())
# og_xlim_raw = og_ylim_raw
# new_ylim_raw = rawdata_axes.get_ylim()

og_ylim_contrast = contrast_axes.get_ylim()
ypos = og_ylim_contrast[1]
# spine_ylim = 1 if not proportional else 0
# rawdata_axes.hlines(
# new_ylim_raw[spine_ylim], og_xlim_raw[0], og_xlim_raw[1], **redraw_axes_kwargs
# )
# og_ylim_contrast = contrast_axes.get_ylim()
# ypos = og_ylim_contrast[spine_ylim]

og_xlim_contrast = contrast_axes.get_xlim()
contrast_axes.hlines(
ypos, og_xlim_contrast[0], og_xlim_contrast[1], **redraw_axes_kwargs
)
# og_xlim_contrast = contrast_axes.get_xlim()
# contrast_axes.hlines(
# ypos, og_xlim_contrast[0], og_xlim_contrast[1], **redraw_axes_kwargs
# )

if show_delta2:
if plot_kwargs["delta2_label"] is not None:
Expand All @@ -1265,14 +1272,16 @@ def General_Plot_Aesthetic_Adjustments(show_delta2, show_mini_meta, contrast_axe
delta2_label = "deltas' g"
raise NotImplementedError("Delta2 is not yet supported for horizontal plots.")

# Set custom ylims, if they were specified.
swarm_ylim = plot_kwargs["swarm_ylim"]
contrast_ylim = plot_kwargs["contrast_ylim"]
if swarm_ylim is None:
swarm_ylim = rawdata_axes.get_ylim()
rawdata_axes.set_ylim(swarm_ylim[1], swarm_ylim[0])
if plot_kwargs['contrast_ylim'] is None:
contrast_ylim = contrast_axes.get_ylim()
contrast_axes.set_ylim(contrast_ylim[1], contrast_ylim[0])
if not proportional:
if swarm_ylim is None:
swarm_ylim = rawdata_axes.get_ylim()
rawdata_axes.set_ylim(swarm_ylim[1], swarm_ylim[0])
if plot_kwargs['contrast_ylim'] is None:
contrast_ylim = contrast_axes.get_ylim()
contrast_axes.set_ylim(contrast_ylim[1], contrast_ylim[0])

else:
contrast_axes.set_ylabel(contrast_label, fontsize=fontsize_contrastylabel)
Expand Down
198 changes: 155 additions & 43 deletions dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
__all__ = ['halfviolin', 'get_swarm_spans', 'error_bar', 'check_data_matches_labels', 'normalize_dict', 'width_determine',
'single_sankey', 'sankeydiag', 'summary_bars_plotter', 'contrast_bars_plotter', 'swarm_bars_plotter',
'delta_text_plotter', 'DeltaDotsPlotter', 'slopegraph_plotter', 'plot_minimeta_or_deltadelta_violins',
'effect_size_curve_plotter', 'grid_key_WIP', 'barplotter', 'swarmplot', 'SwarmPlot']
'effect_size_curve_plotter', 'grid_key_WIP', 'barplotter', 'table_for_horizontal_plots', 'swarmplot',
'SwarmPlot']

# %% ../nbs/API/plot_tools.ipynb 4
import math
Expand Down Expand Up @@ -100,7 +101,11 @@ def error_bar(

if ax is None:
ax = plt.gca()
ax_ylims = ax.get_ylim()

if horizontal:
ax_ylims = ax.get_xlim()
else:
ax_ylims = ax.get_ylim()
ax_yspan = np.abs(ax_ylims[1] - ax_ylims[0])
gap_width = ax_yspan * gap_width_percent / 100

Expand Down Expand Up @@ -898,7 +903,8 @@ def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object
def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object,
ticks_to_plot: list, contrast_bars_kwargs: dict, color_col: str,
plot_palette_raw: dict, show_mini_meta: bool, mini_meta_delta: object,
show_delta2: bool, delta_delta: object, proportional: bool, is_paired: bool):
show_delta2: bool, delta_delta: object, proportional: bool, is_paired: bool,
horizontal: bool = False):
"""
Add contrast bars to the contrast plot.
Expand Down Expand Up @@ -930,6 +936,8 @@ def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: o
Whether the data is proportional.
is_paired : bool
Whether the data is paired.
horizontal : bool
Whether the plot is horizontal.
"""
contrast_means = []
for j, tick in enumerate(ticks_to_plot):
Expand All @@ -943,14 +951,26 @@ def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: o
else list(plot_palette_raw.values())
)
contrast_bars_kwargs.pop('color')
for contrast_bars_x,contrast_bars_y in zip(ticks_to_plot, contrast_means):
ax_to_plot.add_patch(mpatches.Rectangle((contrast_bars_x-0.25,0),0.5, contrast_bars_y, zorder=-1, color=contrast_bars_colors[contrast_bars_x], **contrast_bars_kwargs))

if show_mini_meta:
ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, mini_meta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))
if horizontal:
for contrast_bars_x,contrast_bars_y in zip(ticks_to_plot, contrast_means):
ax_to_plot.add_patch(mpatches.Rectangle((0,contrast_bars_x-0.5),contrast_bars_y, 0.5, zorder=-10, color=contrast_bars_colors[contrast_bars_x], **contrast_bars_kwargs))

if show_mini_meta:
ax_to_plot.add_patch(mpatches.Rectangle((0, max(swarm_plot_ax.get_yticks())+2-0.5), mini_meta_delta.difference, 0.5, zorder=-10, color='black', **contrast_bars_kwargs))

if show_delta2:
ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, delta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))
if show_delta2:
ax_to_plot.add_patch(mpatches.Rectangle((0, max(swarm_plot_ax.get_yticks())+2-0.5), delta_delta.difference, 0.5, zorder=-10, color='black', **contrast_bars_kwargs))

else:
for contrast_bars_x,contrast_bars_y in zip(ticks_to_plot, contrast_means):
ax_to_plot.add_patch(mpatches.Rectangle((contrast_bars_x-0.25,0),0.5, contrast_bars_y, zorder=-1, color=contrast_bars_colors[contrast_bars_x], **contrast_bars_kwargs))

if show_mini_meta:
ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, mini_meta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))

if show_delta2:
ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, delta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))

def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
swarm_bars_kwargs: dict, color_col: str, plot_palette_raw: dict, is_paired: bool):
Expand Down Expand Up @@ -1499,40 +1519,132 @@ def grid_key_WIP(is_paired, idx, all_plot_groups, gridkey_rows, rawdata_axes, co
rawdata_axes.get_xaxis().set_visible(False)
contrast_axes.get_xaxis().set_visible(False)

def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color, plot_palette_bar, plot_kwargs, barplot_kwargs):
# Plot the raw data as a barplot.
bar1_df = pd.DataFrame(
{xvar: all_plot_groups, "proportion": np.ones(len(all_plot_groups))}
)
bar1 = sns.barplot(
data=bar1_df,
x=xvar,
y="proportion",
ax=rawdata_axes,
order=all_plot_groups,
linewidth=2,
facecolor=(1, 1, 1, 0),
edgecolor=bar_color,
zorder=1,
)
bar2 = sns.barplot(
data=plot_data,
x=xvar,
y=yvar,
ax=rawdata_axes,
order=all_plot_groups,
palette=plot_palette_bar,
zorder=1,
**barplot_kwargs
)
# adjust the width of bars
bar_width = plot_kwargs["bar_width"]
for bar in bar1.patches:
x = bar.get_x()
width = bar.get_width()
centre = x + width / 2.0
bar.set_x(centre - bar_width / 2.0)
bar.set_width(bar_width)
def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color, plot_palette_bar,
plot_kwargs, barplot_kwargs, horizontal=False):

if horizontal:
# Plot the raw data as a barplot.
bar1_df = pd.DataFrame(
{xvar: np.ones(len(all_plot_groups)), "proportion": all_plot_groups}
)
bar1 = sns.barplot(
data=bar1_df,
x=xvar,
y="proportion",
ax=rawdata_axes,
order=all_plot_groups,
linewidth=2,
facecolor=(1, 1, 1, 0),
edgecolor=bar_color,
zorder=1,
orient="h",
)
bar2 = sns.barplot(
data=plot_data,
x=yvar,
y=xvar,
ax=rawdata_axes,
order=all_plot_groups,
palette=plot_palette_bar,
zorder=1,
orient="h",
**barplot_kwargs
)
# adjust the width of bars
bar_width = plot_kwargs["bar_width"]
for bar in bar1.patches:
y = bar.get_y()
height = bar.get_height()
centre = y + height / 2.0
bar.set_y(centre - bar_width / 2.0)
bar.set_height(bar_width)
else:
# Plot the raw data as a barplot.
bar1_df = pd.DataFrame(
{xvar: all_plot_groups, "proportion": np.ones(len(all_plot_groups))}
)
bar1 = sns.barplot(
data=bar1_df,
x=xvar,
y="proportion",
ax=rawdata_axes,
order=all_plot_groups,
linewidth=2,
facecolor=(1, 1, 1, 0),
edgecolor=bar_color,
zorder=1,
)
bar2 = sns.barplot(
data=plot_data,
x=xvar,
y=yvar,
ax=rawdata_axes,
order=all_plot_groups,
palette=plot_palette_bar,
zorder=1,
**barplot_kwargs
)
# adjust the width of bars
bar_width = plot_kwargs["bar_width"]
for bar in bar1.patches:
x = bar.get_x()
width = bar.get_width()
centre = x + width / 2.0
bar.set_x(centre - bar_width / 2.0)
bar.set_width(bar_width)

def table_for_horizontal_plots(effectsize_df, ax, contrast_axes, ticks_to_plot, show_mini_meta, show_delta2, table_kwargs):
table_color = table_kwargs['color']
table_alpha = table_kwargs['alpha']
table_font_size = table_kwargs['fontsize'] if table_kwargs['text_units'] == None else table_kwargs['fontsize']-2
table_text_color = table_kwargs['text_color']
text_units = '' if table_kwargs['text_units'] == None else table_kwargs['text_units']
table_gap_dashes = table_kwargs['paired_gap_dashes']
fontsize_label = table_kwargs['fontsize_label']

### Create a table of deltas
cols=['Δ','N']
lst = []
for n in np.arange(0, len(effectsize_df.results.difference), 1):
lst.append([effectsize_df.results.difference[n],0])
if show_mini_meta:
lst.append([effectsize_df.mini_meta_delta.difference,0])
elif show_delta2:
lst.append([effectsize_df.delta_delta.difference,0])
tab = pd.DataFrame(lst, columns=cols)


### Plot the text
for i,loc in zip(tab.index, ticks_to_plot):
if show_mini_meta or show_delta2:
loc_new = loc if loc != 0.25 else loc+0.25
ax.text(0.5, loc_new, "{:+.2f}".format(tab.iloc[i,0])+text_units,ha="center", va="center", color=table_text_color,size=table_font_size)
else:
ax.text(0.5, loc, "{:+.2f}".format(tab.iloc[i,0])+text_units,ha="center", va="center", color=table_text_color,size=table_font_size)

# ### Plot the dashes
# if show_mini_meta or show_delta2:
# no_contrast_positions = list(set([int(x-0.5) for x in ticks_to_plot[:-1]]) ^ set(np.arange(2,Num_Exps+2,1)))
# else:
# no_contrast_positions = list(set([int(x-0.5) for x in ypos]) ^ set(np.arange(0,Num_Exps,1)))

# if table_gap_dashes or not is_paired or multi_paired_control:
# if not (mini_meta or delta2):
# for i in no_contrast_positions:
# rawdata_axes.table_axes.text(0.5, i+1, "—",ha="center", va="center", color=table_text_color,size=table_font_size)


### Parameters for table
ax.axvspan(0, 1, facecolor=table_color, alpha=table_alpha) #### Plot the background color
ax.set_xticks([0.5])
ax.set_xticklabels([])
ax.set_ylim(contrast_axes.get_ylim())
ax.set_yticks([])
ax.set_yticklabels([])
ax.tick_params(left=False, bottom=False)
ax.set_xlabel('Δ', fontsize=fontsize_label) # Set the x-axis label - hardcoded for now
sns.despine(ax=ax, left=True, bottom=True)
...

# %% ../nbs/API/plot_tools.ipynb 6
def swarmplot(
Expand Down
Loading

0 comments on commit 2f72fc0

Please sign in to comment.