Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forest plot API and Tutorial changes #185

Merged
merged 20 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .cursorignore
Jacobluke- marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
# Ignore all common image file formats
*.jpg
*.jpeg
*.png
*.gif
*.bmp
*.tiff
*.webp
*.svg
4 changes: 3 additions & 1 deletion dabest/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@
'dabest.forest_plot': { 'dabest.forest_plot.extract_plot_data': ( 'API/forest_plot.html#extract_plot_data',
'dabest/forest_plot.py'),
'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'),
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py'),
'dabest.forest_plot.map_effect_attribute': ( 'API/forest_plot.html#map_effect_attribute',
'dabest/forest_plot.py')},
'dabest.misc_tools': { 'dabest.misc_tools.Cumming_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#cumming_plot_aesthetic_adjustments',
'dabest/misc_tools.py'),
'dabest.misc_tools.Gardner_Altman_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#gardner_altman_plot_aesthetic_adjustments',
Expand Down
81 changes: 52 additions & 29 deletions dabest/forest_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.

# %% auto 0
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']

# %% ../nbs/API/forest_plot.ipynb 5
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -74,28 +74,42 @@ def extract_plot_data(contrast_plot_data, contrast_type):

return bootstraps, differences, bcalows, bcahighs

def map_effect_attribute(attribute_key):
# Check if the attribute key exists in the dictionary
effect_attr_map = {
"mean_diff": "Mean Difference",
"median_diff": "Median Difference",
"cliffs_delta": "Cliffs Delta",
"cohens_d": "Cohens d",
"hedges_g": "Hedges g",
"delta_g": "Delta g"
}
if attribute_key in effect_attr_map:
return effect_attr_map[attribute_key]
else:
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.") # Return a default value or message if the key is not found

def forest_plot(
contrasts: List,
selected_indices: Optional[List] = None,
contrast_type: str = "delta2",
xticklabels: Optional[List] = None,
effect_size: str = "mean_diff",
contrast_labels: List[str] = None,
ylabel: str = "value",
ylabel: str = "effect size",
plot_elements_to_extract: Optional[List] = None,
title: str = "ΔΔ Forest",
custom_palette: Optional[Union[dict, list, str]] = None,
fontsize: int = 20,
fontsize: int = 12,
title_font_size: int =16,
violin_kwargs: Optional[dict] = None,
marker_size: int = 20,
ci_line_width: float = 2.5,
zero_line_width: int = 1,
desat_violin: float = 1,
remove_spines: bool = True,
ax: Optional[plt.Axes] = None,
additional_plotting_kwargs: Optional[dict] = None,
rotation_for_xlabels: int = 45,
alpha_violin_plot: float = 0.4,
alpha_violin_plot: float = 0.8,
horizontal: bool = False # New argument for horizontal orientation
)-> plt.Figure:
"""
Expand All @@ -108,11 +122,9 @@ def forest_plot(
selected_indices : Optional[List], default=None
Indices of specific contrasts to plot, if not plotting all.
analysis_type : str
the type of analysis (e.g., 'delta2', 'minimeta').
xticklabels : Optional[List], default=None
Custom labels for the x-axis ticks.
the type of analysis (e.g., 'delta2', 'mini_meta').
effect_size : str
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).
contrast_labels : List[str]
Labels for each contrast.
ylabel : str
Expand All @@ -127,14 +139,14 @@ def forest_plot(
Custom color palette for the plot.
fontsize : int
Font size for text elements in the plot.
title_font_size: int =16
Font size for text of plot title.
violin_kwargs : Optional[dict], default=None
Additional arguments for violin plot customization.
marker_size : int
Marker size for plotting mean differences or effect sizes.
ci_line_width : float
Width of confidence interval lines.
zero_line_width : int
Width of the line indicating zero effect size.
remove_spines : bool, default=False
If True, removes top and right plot spines.
ax : Optional[plt.Axes], default=None
Expand Down Expand Up @@ -163,14 +175,13 @@ def forest_plot(
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
raise TypeError("The `selected_indices` must be a list of integers or `None`.")

# For the 'contrast_type' parameter
if not isinstance(contrast_type, str):
raise TypeError("The `contrast_type` argument must be a string.")

if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
raise TypeError("The `xticklabels` must be a list of strings or `None`.")

raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")

# For the 'effect_size' parameter
if not isinstance(effect_size, str):
raise TypeError("The `effect_size` argument must be a string.")
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.")

if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
Expand All @@ -193,9 +204,6 @@ def forest_plot(
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
raise TypeError("`ci_line_width` must be a positive integer or float.")

if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
raise TypeError("`zero_line_width` must be a positive integer or float.")

if not isinstance(remove_spines, bool):
raise TypeError("`remove_spines` must be a boolean value.")

Expand All @@ -211,6 +219,8 @@ def forest_plot(
if not isinstance(horizontal, bool):
raise TypeError("`horizontal` must be a boolean value.")

if (effect_size and isinstance(effect_size, str)):
ylabel = map_effect_attribute(effect_size)
# Load plot data
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)

Expand Down Expand Up @@ -252,7 +262,7 @@ def forest_plot(
if custom_palette:
if isinstance(custom_palette, dict):
violin_colors = [
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels
]
elif isinstance(custom_palette, list):
violin_colors = custom_palette[: len(contrasts)]
Expand All @@ -264,12 +274,18 @@ def forest_plot(
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
)
else:
violin_colors = sns.color_palette()[: len(contrasts)]
violin_colors = sns.color_palette(n_colors=len(contrasts))

violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]

for patch, color in zip(v["bodies"], violin_colors):
patch.set_facecolor(color)
patch.set_alpha(alpha_violin_plot)

if horizontal:
ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)
else:
ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)

# Flipping the axes for plotting based on 'horizontal'
for k in range(1, len(contrasts) + 1):
if horizontal:
Expand All @@ -282,19 +298,26 @@ def forest_plot(
# Adjusting labels, ticks, and limits based on 'horizontal'
if horizontal:
ax.set_yticks(range(1, len(contrasts) + 1))
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
ax.set_xlabel(ylabel, fontsize=fontsize)
ax.set_ylim([0.7, len(contrasts) + 0.5])
else:
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_xlim([0.7, len(contrasts) + 0.5])

# Setting the title and adjusting spines as before
ax.set_title(title, fontsize=fontsize)
ax.set_title(title, fontsize=title_font_size)
if remove_spines:
for spine in ax.spines.values():
spine.set_visible(False)

if horizontal:
ax.spines['left'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
else:
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
# Apply additional customizations if provided
if additional_plotting_kwargs:
ax.set(**additional_plotting_kwargs)
Expand Down
79 changes: 51 additions & 28 deletions nbs/API/forest_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,42 @@
" \n",
" return bootstraps, differences, bcalows, bcahighs\n",
"\n",
"def map_effect_attribute(attribute_key):\n",
" # Check if the attribute key exists in the dictionary\n",
" effect_attr_map = {\n",
" \"mean_diff\": \"Mean Difference\",\n",
" \"median_diff\": \"Median Difference\",\n",
" \"cliffs_delta\": \"Cliffs Delta\",\n",
" \"cohens_d\": \"Cohens d\",\n",
" \"hedges_g\": \"Hedges g\",\n",
" \"delta_g\": \"Delta g\"\n",
" }\n",
" if attribute_key in effect_attr_map:\n",
" return effect_attr_map[attribute_key]\n",
" else:\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\") # Return a default value or message if the key is not found\n",
"\n",
"def forest_plot(\n",
" contrasts: List,\n",
" selected_indices: Optional[List] = None,\n",
" contrast_type: str = \"delta2\",\n",
" xticklabels: Optional[List] = None,\n",
" effect_size: str = \"mean_diff\",\n",
" contrast_labels: List[str] = None,\n",
" ylabel: str = \"value\",\n",
" ylabel: str = \"effect size\",\n",
" plot_elements_to_extract: Optional[List] = None,\n",
" title: str = \"ΔΔ Forest\",\n",
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
" fontsize: int = 20,\n",
" fontsize: int = 12,\n",
" title_font_size: int =16,\n",
" violin_kwargs: Optional[dict] = None,\n",
" marker_size: int = 20,\n",
" ci_line_width: float = 2.5,\n",
" zero_line_width: int = 1,\n",
" desat_violin: float = 1,\n",
" remove_spines: bool = True,\n",
" ax: Optional[plt.Axes] = None,\n",
" additional_plotting_kwargs: Optional[dict] = None,\n",
" rotation_for_xlabels: int = 45,\n",
" alpha_violin_plot: float = 0.4,\n",
" alpha_violin_plot: float = 0.8,\n",
" horizontal: bool = False # New argument for horizontal orientation\n",
")-> plt.Figure:\n",
" \"\"\" \n",
Expand All @@ -167,11 +181,9 @@
" selected_indices : Optional[List], default=None\n",
" Indices of specific contrasts to plot, if not plotting all.\n",
" analysis_type : str\n",
" the type of analysis (e.g., 'delta2', 'minimeta').\n",
" xticklabels : Optional[List], default=None\n",
" Custom labels for the x-axis ticks.\n",
" the type of analysis (e.g., 'delta2', 'mini_meta').\n",
" effect_size : str\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff').\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).\n",
" contrast_labels : List[str]\n",
" Labels for each contrast.\n",
" ylabel : str\n",
Expand All @@ -186,14 +198,14 @@
" Custom color palette for the plot.\n",
" fontsize : int\n",
" Font size for text elements in the plot.\n",
" title_font_size: int =16\n",
" Font size for text of plot title.\n",
" violin_kwargs : Optional[dict], default=None\n",
" Additional arguments for violin plot customization.\n",
" marker_size : int\n",
" Marker size for plotting mean differences or effect sizes.\n",
" ci_line_width : float\n",
" Width of confidence interval lines.\n",
" zero_line_width : int\n",
" Width of the line indicating zero effect size.\n",
" remove_spines : bool, default=False\n",
" If True, removes top and right plot spines.\n",
" ax : Optional[plt.Axes], default=None\n",
Expand Down Expand Up @@ -222,14 +234,13 @@
" if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):\n",
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
" \n",
" # For the 'contrast_type' parameter\n",
" if not isinstance(contrast_type, str):\n",
" raise TypeError(\"The `contrast_type` argument must be a string.\")\n",
" \n",
" if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):\n",
" raise TypeError(\"The `xticklabels` must be a list of strings or `None`.\")\n",
" \n",
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.\")\n",
"\n",
" # For the 'effect_size' parameter\n",
" if not isinstance(effect_size, str):\n",
" raise TypeError(\"The `effect_size` argument must be a string.\")\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.\")\n",
" \n",
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
Expand All @@ -252,9 +263,6 @@
" if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:\n",
" raise TypeError(\"`ci_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:\n",
" raise TypeError(\"`zero_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(remove_spines, bool):\n",
" raise TypeError(\"`remove_spines` must be a boolean value.\")\n",
" \n",
Expand All @@ -270,6 +278,8 @@
" if not isinstance(horizontal, bool):\n",
" raise TypeError(\"`horizontal` must be a boolean value.\")\n",
"\n",
" if (effect_size and isinstance(effect_size, str)):\n",
" ylabel = map_effect_attribute(effect_size)\n",
" # Load plot data\n",
" contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)\n",
"\n",
Expand Down Expand Up @@ -311,7 +321,7 @@
" if custom_palette:\n",
" if isinstance(custom_palette, dict):\n",
" violin_colors = [\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrasts\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels\n",
" ]\n",
" elif isinstance(custom_palette, list):\n",
" violin_colors = custom_palette[: len(contrasts)]\n",
Expand All @@ -323,12 +333,18 @@
" f\"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.\"\n",
" )\n",
" else:\n",
" violin_colors = sns.color_palette()[: len(contrasts)]\n",
" violin_colors = sns.color_palette(n_colors=len(contrasts))\n",
"\n",
" violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]\n",
" \n",
" for patch, color in zip(v[\"bodies\"], violin_colors):\n",
" patch.set_facecolor(color)\n",
" patch.set_alpha(alpha_violin_plot)\n",
"\n",
" if horizontal:\n",
" ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)\n",
" else:\n",
" ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)\n",
" \n",
" # Flipping the axes for plotting based on 'horizontal'\n",
" for k in range(1, len(contrasts) + 1):\n",
" if horizontal:\n",
Expand All @@ -341,19 +357,26 @@
" # Adjusting labels, ticks, and limits based on 'horizontal'\n",
" if horizontal:\n",
" ax.set_yticks(range(1, len(contrasts) + 1))\n",
" ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)\n",
" ax.set_xlabel(ylabel, fontsize=fontsize)\n",
" ax.set_ylim([0.7, len(contrasts) + 0.5])\n",
" else:\n",
" ax.set_xticks(range(1, len(contrasts) + 1))\n",
" ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_ylabel(ylabel, fontsize=fontsize)\n",
" ax.set_xlim([0.7, len(contrasts) + 0.5])\n",
"\n",
" # Setting the title and adjusting spines as before\n",
" ax.set_title(title, fontsize=fontsize)\n",
" ax.set_title(title, fontsize=title_font_size)\n",
" if remove_spines:\n",
" for spine in ax.spines.values():\n",
" spine.set_visible(False)\n",
"\n",
" if horizontal:\n",
" ax.spines['left'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" else:\n",
" ax.spines['top'].set_visible(False)\n",
" ax.spines['bottom'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" # Apply additional customizations if provided\n",
" if additional_plotting_kwargs:\n",
" ax.set(**additional_plotting_kwargs)\n",
Expand Down
Loading
Loading