Skip to content

Commit

Permalink
Further trimmed plotter and created function in misc_tools to extract…
Browse files Browse the repository at this point in the history
… the contrast ticks to plot
  • Loading branch information
JAnns98 committed Sep 10, 2024
1 parent 6f8cb0d commit 21509c9
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 233 deletions.
6 changes: 5 additions & 1 deletion dabest/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@
'dabest/forest_plot.py'),
'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'),
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
'dabest.misc_tools': { 'dabest.misc_tools.get_color_palette': ('API/misc_tools.html#get_color_palette', 'dabest/misc_tools.py'),
'dabest.misc_tools': { 'dabest.misc_tools.add_counts_to_ticks': ( 'API/misc_tools.html#add_counts_to_ticks',
'dabest/misc_tools.py'),
'dabest.misc_tools.extract_contrast_plotting_ticks': ( 'API/misc_tools.html#extract_contrast_plotting_ticks',
'dabest/misc_tools.py'),
'dabest.misc_tools.get_color_palette': ('API/misc_tools.html#get_color_palette', 'dabest/misc_tools.py'),
'dabest.misc_tools.get_kwargs': ('API/misc_tools.html#get_kwargs', 'dabest/misc_tools.py'),
'dabest.misc_tools.get_params': ('API/misc_tools.html#get_params', 'dabest/misc_tools.py'),
'dabest.misc_tools.get_plot_groups': ('API/misc_tools.html#get_plot_groups', 'dabest/misc_tools.py'),
Expand Down
71 changes: 70 additions & 1 deletion dabest/misc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# %% auto 0
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_params', 'get_kwargs', 'get_color_palette',
'initialize_fig', 'get_plot_groups']
'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks', 'extract_contrast_plotting_ticks']

# %% ../nbs/API/misc_tools.ipynb 4
import datetime as dt
Expand All @@ -11,6 +11,7 @@
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib

# %% ../nbs/API/misc_tools.ipynb 5
def merge_two_dicts(
Expand Down Expand Up @@ -493,3 +494,71 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups):
plot_groups = [item for i in temp_idx for item in i]
temp_all_plot_groups = all_plot_groups if not proportional else plot_groups
return temp_idx, temp_all_plot_groups


def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):
    """Append the per-group sample size (``N = ...``) to each raw-data x tick label.

    Parameters
    ----------
    plot_data : pandas.DataFrame
        Data that was plotted; it is grouped by `xvar` and counted on `yvar`.
    xvar : str
        Column holding the group names shown on the x axis.
    yvar : str
        Column whose non-null entries are counted for every group.
    rawdata_axes : matplotlib.axes.Axes
        Axes whose x tick labels are rewritten in place.
    plot_kwargs : dict
        Must contain the key ``"fontsize_rawxlabel"``. When its value is not
        None it is used as the font size of the new labels; when it is None
        the labels are still written and the font size is left untouched.

    Raises
    ------
    KeyError
        If ``"fontsize_rawxlabel"`` is missing from `plot_kwargs`, or a tick
        label does not correspond to a group present in `plot_data[xvar]`.
    """
    counts = plot_data.groupby(xvar).count()[yvar]

    # Freeze the current tick positions before relabelling: matplotlib's
    # set_xticklabels is meant to be used with fixed tick locations.
    ticks_loc = rawdata_axes.get_xticks()
    rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))

    ticks_with_counts = []
    for xticklab in rawdata_axes.xaxis.get_ticklabels():
        label = xticklab.get_text()
        # For a multi-line label only the last line is the group name used
        # to look up the count; the full label is kept in the output.
        group = label.rsplit("\n", 1)[-1]
        ticks_with_counts.append("{}\nN = {}".format(label, str(counts.loc[group])))

    # Only forward a font size when one was supplied, so that a None value
    # neither leaves the name unbound nor suppresses the relabelling.
    label_kwargs = {}
    fontsize_rawxlabel = plot_kwargs["fontsize_rawxlabel"]
    if fontsize_rawxlabel is not None:
        label_kwargs["fontsize"] = fontsize_rawxlabel
    rawdata_axes.set_xticklabels(ticks_with_counts, **label_kwargs)


def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_groups, idx, sankey_control_group):
    """Work out which x ticks of the contrast axes carry an effect size.

    Parameters
    ----------
    is_paired : str or None
        Pairing mode; only the value ``"baseline"`` (together with
        `show_pairs`) selects the paired-baseline layout.
    show_pairs : bool
        Whether paired lines/sankey diagrams are drawn.
    two_col_sankey : bool
        Whether the two-column sankey layout is used.
    plot_groups : sequence
        Only its length is used, to enumerate candidate tick positions.
    idx : sequence of sequences
        Group tuples; the length of each tuple drives the cumulative offsets.
    sankey_control_group : sequence or None
        Only read (via ``len``) for a two-column sankey that is not in the
        paired-baseline layout.

    Returns
    -------
    tuple
        ``(ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast,
        ticks_to_start_twocol_sankey)``. The last two entries are None in
        the layouts that do not define them.
    """
    ticks_to_skip_contrast = None
    ticks_to_start_twocol_sankey = None
    paired_baseline = is_paired == "baseline" and show_pairs

    if two_col_sankey:
        # The two-column sankey layout never reports skipped ticks.
        ticks_to_skip = []
        if paired_baseline:
            # True division is kept on purpose: the ticks come back as floats.
            ticks_to_plot = np.arange(0, len(plot_groups) / 2).tolist()
        else:
            # Every position except the one equal to the number of controls.
            control_tick = len(sankey_control_group)
            ticks_to_plot = [
                t for t in range(0, len(plot_groups)) if t != control_tick
            ]
        # Start offset of each idx tuple, counting len(tuple) - 1 per tuple.
        ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()
        ticks_to_start_twocol_sankey.pop()
        ticks_to_start_twocol_sankey.insert(0, 0)
    else:
        # The first member of every idx tuple is skipped: position 0 plus the
        # running total of the preceding tuple lengths.
        ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()
        ticks_to_skip.insert(0, 0)
        # Then obtain the ticks where we have to plot the effect sizes.
        ticks_to_plot = [
            t for t in range(0, len(plot_groups)) if t not in ticks_to_skip
        ]
        if paired_baseline:
            # Same positions as ticks_to_skip, returned as an independent list.
            ticks_to_skip_contrast = list(ticks_to_skip)

    return ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast, ticks_to_start_twocol_sankey
1 change: 0 additions & 1 deletion dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,6 @@ def plot_minimeta_or_deltadelta_violins(show_mini_meta, effectsize_df, ci_type,
contrast_xtick_labels.extend(["", "delta-delta"])

return contrast_xtick_labels
...

# %% ../nbs/API/plot_tools.ipynb 6
def swarmplot(
Expand Down
129 changes: 16 additions & 113 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
get_color_palette,
initialize_fig,
get_plot_groups,
add_counts_to_ticks,
extract_contrast_plotting_ticks,
)
from .plot_tools import (
halfviolin,
Expand Down Expand Up @@ -150,7 +152,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
proportional=proportional, float_contrast=float_contrast, face_color=face_color,
h_space_cummings=h_space_cummings)

# Plotting.
# Plotting the rawdata.
if show_pairs:
temp_idx, temp_all_plot_groups = get_plot_groups(is_paired=is_paired, idx=idx, proportional=proportional,
all_plot_groups=all_plot_groups)
Expand Down Expand Up @@ -257,7 +259,6 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):

# Plot the error bars.
if group_summaries is not None:

if proportional:
group_summaries_method = "proportional_error_bar"
group_summaries_offset = 0
Expand Down Expand Up @@ -308,25 +309,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
)

# Add the counts to the rawdata axes xticks.
counts = plot_data.groupby(xvar).count()[yvar]
ticks_with_counts = []
ticks_loc = rawdata_axes.get_xticks()
rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))
for xticklab in rawdata_axes.xaxis.get_ticklabels():
t = xticklab.get_text()
if t.rfind("\n") != -1:
te = t[t.rfind("\n") + len("\n") :]
N = str(counts.loc[te])
te = t
else:
te = t
N = str(counts.loc[te])

ticks_with_counts.append("{}\nN = {}".format(te, N))

if plot_kwargs["fontsize_rawxlabel"] is not None:
fontsize_rawxlabel = plot_kwargs["fontsize_rawxlabel"]
rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel)
add_counts_to_ticks(plot_data=plot_data, xvar=xvar, yvar=yvar, rawdata_axes=rawdata_axes, plot_kwargs=plot_kwargs)

# Save the handles and labels for the legend.
handles, labels = rawdata_axes.get_legend_handles_labels()
Expand All @@ -335,48 +318,21 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
if bootstraps_color_by_group is False:
rawdata_axes.legend().set_visible(False)

# Enforce the xtick of rawdata_axes to be 0 and 1 after drawing only one sankey
# Enforce the xtick of rawdata_axes to be 0 and 1 after drawing only one sankey ----> Redundant code
if one_sankey:
rawdata_axes.set_xticks([0, 1])

# Plot effect sizes and bootstraps.
# Take note of where the `control` groups are.
if is_paired == "baseline" and show_pairs:
if two_col_sankey:
ticks_to_skip = []
ticks_to_plot = np.arange(0, len(temp_all_plot_groups) / 2).tolist()
ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()
ticks_to_start_twocol_sankey.pop()
ticks_to_start_twocol_sankey.insert(0, 0)
else:
# ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist()
# ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist()
ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()
ticks_to_skip.insert(0, 0)
# Then obtain the ticks where we have to plot the effect sizes.
ticks_to_plot = [
t for t in range(0, len(all_plot_groups)) if t not in ticks_to_skip
]
ticks_to_skip_contrast = np.cumsum([(len(t)) for t in idx])[:-1].tolist()
ticks_to_skip_contrast.insert(0, 0)
else:
if two_col_sankey:
ticks_to_skip = [len(sankey_control_group)]
# Then obtain the ticks where we have to plot the effect sizes.
ticks_to_plot = [
t for t in range(0, len(temp_idx)) if t not in ticks_to_skip
]
ticks_to_skip = []
ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()
ticks_to_start_twocol_sankey.pop()
ticks_to_start_twocol_sankey.insert(0, 0)
else:
ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()
ticks_to_skip.insert(0, 0)
# Then obtain the ticks where we have to plot the effect sizes.
ticks_to_plot = [
t for t in range(0, len(all_plot_groups)) if t not in ticks_to_skip
]
plot_groups = temp_all_plot_groups if (is_paired == "baseline" and show_pairs and two_col_sankey) else temp_idx if (two_col_sankey) else all_plot_groups

(ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast,
ticks_to_start_twocol_sankey) = extract_contrast_plotting_ticks(is_paired=is_paired,
show_pairs=show_pairs,
two_col_sankey=two_col_sankey,
plot_groups=plot_groups,
idx=idx,
sankey_control_group=sankey_control_group if two_col_sankey else None,
)

# Plot the bootstraps, then the effect sizes and CIs.
es_marker_size = plot_kwargs["es_marker_size"]
Expand Down Expand Up @@ -448,59 +404,6 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
contrast_xtick_labels=contrast_xtick_labels, effect_size=effect_size
)

# if show_mini_meta:
# mini_meta_delta = effectsize_df.mini_meta_delta
# data = mini_meta_delta.bootstraps_weighted_delta
# difference = mini_meta_delta.difference
# if ci_type == "bca":
# ci_low = mini_meta_delta.bca_low
# ci_high = mini_meta_delta.bca_high
# else:
# ci_low = mini_meta_delta.pct_low
# ci_high = mini_meta_delta.pct_high
# else:
# delta_delta = effectsize_df.delta_delta
# data = delta_delta.bootstraps_delta_delta
# difference = delta_delta.difference
# if ci_type == "bca":
# ci_low = delta_delta.bca_low
# ci_high = delta_delta.bca_high
# else:
# ci_low = delta_delta.pct_low
# ci_high = delta_delta.pct_high
# # Create the violinplot.
# # New in v0.2.6: drop negative infinities before plotting.
# position = max(rawdata_axes.get_xticks()) + 2
# v = contrast_axes.violinplot(
# data[~np.isinf(data)], positions=[position], **violinplot_kwargs
# )

# fc = "grey"

# halfviolin(v, fill_color=fc, alpha=halfviolin_alpha)

# # Plot the effect size.
# contrast_axes.plot(
# [position],
# difference,
# marker="o",
# color=ytick_color,
# markersize=es_marker_size,
# )
# # Plot the confidence interval.
# contrast_axes.plot(
# [position, position],
# [ci_low, ci_high],
# linestyle="-",
# color=ytick_color,
# linewidth=group_summary_kwargs["lw"],
# )
# if show_mini_meta:
# contrast_xtick_labels.extend(["", "Weighted delta"])
# elif effect_size == "delta_g":
# contrast_xtick_labels.extend(["", "deltas' g"])
# else:
# contrast_xtick_labels.extend(["", "delta-delta"])

# Make sure the contrast_axes x-lims match the rawdata_axes xlims,
# and add an extra violinplot tick for delta-delta plot.
Expand Down Expand Up @@ -1031,7 +934,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
gridkey_rows.append("Ns")
list_of_Ns = []
for i in groups_for_gridkey:
list_of_Ns.append(str(counts.loc[i]))
list_of_Ns.append(str(plot_data.groupby(xvar).count()[yvar].loc[i]))
table_cellcols.append(list_of_Ns)

# Adds a row for effectsizes with effectsize values
Expand Down
73 changes: 71 additions & 2 deletions nbs/API/misc_tools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
"from numpy import repeat\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt"
"import matplotlib.pyplot as plt\n",
"import matplotlib"
]
},
{
Expand Down Expand Up @@ -549,7 +550,75 @@
" # Determine temp_all_plot_groups based on proportional condition\n",
" plot_groups = [item for i in temp_idx for item in i]\n",
" temp_all_plot_groups = all_plot_groups if not proportional else plot_groups\n",
" return temp_idx, temp_all_plot_groups"
" return temp_idx, temp_all_plot_groups\n",
"\n",
"\n",
"def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):\n",
" counts = plot_data.groupby(xvar).count()[yvar]\n",
" ticks_with_counts = []\n",
" ticks_loc = rawdata_axes.get_xticks()\n",
" rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))\n",
" for xticklab in rawdata_axes.xaxis.get_ticklabels():\n",
" t = xticklab.get_text()\n",
" if t.rfind(\"\\n\") != -1:\n",
" te = t[t.rfind(\"\\n\") + len(\"\\n\") :]\n",
" N = str(counts.loc[te])\n",
" te = t\n",
" else:\n",
" te = t\n",
" N = str(counts.loc[te])\n",
"\n",
" ticks_with_counts.append(\"{}\\nN = {}\".format(te, N))\n",
"\n",
" if plot_kwargs[\"fontsize_rawxlabel\"] is not None:\n",
" fontsize_rawxlabel = plot_kwargs[\"fontsize_rawxlabel\"]\n",
" rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel)\n",
"\n",
"\n",
"def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_groups, idx, sankey_control_group):\n",
"\n",
" # Take note of where the `control` groups are.\n",
" ticks_to_skip_contrast = None\n",
" ticks_to_start_twocol_sankey = None\n",
" if is_paired == \"baseline\" and show_pairs:\n",
" if two_col_sankey:\n",
" ticks_to_skip = []\n",
" ticks_to_plot = np.arange(0, len(plot_groups) / 2).tolist()\n",
" ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()\n",
" ticks_to_start_twocol_sankey.pop()\n",
" ticks_to_start_twocol_sankey.insert(0, 0)\n",
" else:\n",
" # ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist()\n",
" # ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist()\n",
" ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()\n",
" ticks_to_skip.insert(0, 0)\n",
" # Then obtain the ticks where we have to plot the effect sizes.\n",
" ticks_to_plot = [\n",
" t for t in range(0, len(plot_groups)) if t not in ticks_to_skip\n",
" ]\n",
" ticks_to_skip_contrast = np.cumsum([(len(t)) for t in idx])[:-1].tolist()\n",
" ticks_to_skip_contrast.insert(0, 0)\n",
" else:\n",
" if two_col_sankey:\n",
" ticks_to_skip = [len(sankey_control_group)]\n",
" # Then obtain the ticks where we have to plot the effect sizes.\n",
" ticks_to_plot = [\n",
" t for t in range(0, len(plot_groups)) if t not in ticks_to_skip\n",
" ]\n",
" ticks_to_skip = []\n",
" ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()\n",
" ticks_to_start_twocol_sankey.pop()\n",
" ticks_to_start_twocol_sankey.insert(0, 0)\n",
" else:\n",
" ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()\n",
" ticks_to_skip.insert(0, 0)\n",
" # Then obtain the ticks where we have to plot the effect sizes.\n",
" ticks_to_plot = [\n",
" t for t in range(0, len(plot_groups)) if t not in ticks_to_skip\n",
" ]\n",
" \n",
" return ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast, ticks_to_start_twocol_sankey\n",
" ..."
]
}
],
Expand Down
3 changes: 1 addition & 2 deletions nbs/API/plot_tools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1285,8 +1285,7 @@
" else:\n",
" contrast_xtick_labels.extend([\"\", \"delta-delta\"])\n",
" \n",
" return contrast_xtick_labels\n",
" ..."
" return contrast_xtick_labels"
]
},
{
Expand Down
Loading

0 comments on commit 21509c9

Please sign in to comment.