Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat forest plot- add tests, horizontal mode and tutorial notebook #164

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 84 additions & 69 deletions dabest/forest_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def load_plot_data(
"cliffs_delta": "cliffs_delta",
"cohens_d": "cohens_d",
"hedges_g": "hedges_g",
"delta_g": "delta_g"
}

contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta"}
contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta_delta"}

effect_attr = effect_attr_map.get(effect_size)
contrast_attr = contrast_attr_map.get(contrast_type, "delta_delta")
contrast_attr = contrast_attr_map.get(contrast_type)

if not effect_attr:
raise ValueError(f"Invalid effect_size: {effect_size}")
Expand All @@ -46,9 +47,9 @@ def load_plot_data(
]


def extract_plot_data(contrast_plot_data, contrast_labels):
def extract_plot_data(contrast_plot_data, contrast_type):
"""Extracts bootstrap, difference, and confidence intervals based on contrast labels."""
if contrast_labels == "mini_meta":
if contrast_type == "mini_meta":
attribute_suffix = "weighted_delta"
else:
attribute_suffix = "delta_delta"
Expand All @@ -57,26 +58,25 @@ def extract_plot_data(contrast_plot_data, contrast_labels):
getattr(result, f"bootstraps_{attribute_suffix}")
for result in contrast_plot_data
]

differences = [result.difference for result in contrast_plot_data]
bcalows = [result.bca_low for result in contrast_plot_data]
bcahighs = [result.bca_high for result in contrast_plot_data]

return bootstraps, differences, bcalows, bcahighs


def forest_plot(
contrasts: List,
selected_indices: Optional[List] = None,
analysis_type: str = "delta2",
contrast_type: str = "delta2",
xticklabels: Optional[List] = None,
effect_size: str = "mean_diff",
contrast_labels: str = "delta_delta",
ylabel: str = "ΔΔ Volume (nL)",
contrast_labels: List[str] = None,
ylabel: str = "value",
plot_elements_to_extract: Optional[List] = None,
title: str = "ΔΔ Forest",
custom_palette: Optional[
Union[dict, list, str]
] = None, # Custom color palette parameter
custom_palette: Optional[Union[dict, list, str]] = None,
fontsize: int = 20,
violin_kwargs: Optional[dict] = None,
marker_size: int = 20,
Expand All @@ -87,73 +87,86 @@ def forest_plot(
additional_plotting_kwargs: Optional[dict] = None,
rotation_for_xlabels: int = 45,
alpha_violin_plot: float = 0.4,
) -> plt.Figure:
horizontal: bool = False # New argument for horizontal orientation
)-> plt.Figure:

"""
Generates a customized forest plot using contrast objects from DABEST-python package or similar.
Generates a customized forest plot using contrast objects. This function supports both horizontal and vertical orientations of the plot, as determined by the 'horizontal' parameter.

Parameters:
contrasts (List): List of contrast objects.
selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
analysis_type (str): Type of analysis ('delta2', 'minimeta').
xticklabels (Optional[List]): Custom labels for x-axis ticks.
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
contrast_labels (str): Labels for each contrast.
ylabel (str): Label for the y-axis.
plot_elements_to_extract (Optional[List]): Plot elements to be extracted for custom plotting.
title (str): Title of the plot.
ylim (Tuple[float, float]): y-axis limits.
custom_palette (Optional[Union[dict, list, str]]): Custom palette for violin plots.
fontsize (int): Font size for labels.
violin_kwargs (Optional[dict]): Additional kwargs for violin plots.
marker_size (int): Size of the markers for mean differences.
ci_line_width (float): Line width for confidence intervals.
zero_line_width (int): Width of the zero line.
remove_spines (bool): Whether to remove the plot spines.
ax (Optional[plt.Axes]): Axes object to plot on, if provided.
additional_plotting_kwargs (Optional[dict]): Additional plotting parameters.
rotation_for_xlabels (int): Rotation angle for x-axis labels.
alpha_violin_plot (float): Transparency level for violin plots.
__________
contrasts (List): List of contrast objects to be plotted.
selected_indices (Optional[List]): Specific indices of contrasts to be plotted, if not plotting all. Default is None, which means all contrasts are plotted.
contrast_type (str): Specifies the type of analysis (e.g., 'delta2', 'minimeta') for the contrasts. This determines the statistical approach used for the contrasts.
xticklabels (Optional[List]): Custom labels for the x-axis ticks. If not provided, the default is to use indices.
effect_size (str): Specifies the type of effect size to be plotted (e.g., 'mean_diff', 'median_diff'). This is crucial for interpreting the results correctly.
contrast_labels (List[str]): Labels for each contrast. These are used for labeling the plot elements and must correspond to the contrasts provided.
ylabel (str): Label for the y-axis. This should describe the data or effect size being plotted and is essential for plot interpretation.
plot_elements_to_extract (Optional[List]): Specifies which plot elements to extract for custom plotting. This allows for more detailed customization of the plot. Default is None.
title (str): The title of the plot. This should provide a concise summary of what the plot represents.
custom_palette (Optional[Union[dict, list, str]]): A custom color palette for the plot. Can be specified as a dictionary mapping contrasts to colors, a list of colors, or a string name of a seaborn or matplotlib colormap. Default is None.
fontsize (int): Font size for all text elements in the plot, including labels, title, and tick labels. This is important for ensuring the plot is readable.
violin_kwargs (Optional[dict]): Additional keyword arguments passed to the violinplot function. This allows for further customization of the violin plots. Default is None.
marker_size (int): Size of the markers used for plotting mean differences or effect sizes. This affects the visual representation of the data points.
ci_line_width (float): Line width for the confidence interval lines. This helps to visually distinguish the confidence intervals.
zero_line_width (int): Width of the line representing the zero effect size or reference line. This is important for identifying the point of no effect.
remove_spines (bool): If True, removes the top and right spines from the plot, which can make the plot look cleaner. Default is False.
ax (Optional[plt.Axes]): An existing matplotlib Axes object to plot on. If None, a new figure and axes are created. This allows for integration into larger figures.
additional_plotting_kwargs (Optional[dict]): Additional keyword arguments for customizing the plot. This provides flexibility for advanced plot customization. Default is None.
rotation_for_xlabels (int): Rotation angle (in degrees) for the x-axis labels. This can help with label readability, especially for long labels. Default is 0.
alpha_violin_plot (float): Transparency level for the violin plots. This can be adjusted to make the plot more visually appealing. Default is 1.0 (fully opaque).
horizontal (bool): If True, plots the forest plot horizontally (with effect sizes along the y-axis). Otherwise, plots vertically with effect sizes along the x-axis. This affects the orientation of the plot.

Returns:
plt.Figure: The matplotlib figure object with the plot.
_______
plt.Figure: The matplotlib figure object containing the generated plot. This object can be further modified or saved as an image file.
"""
from .plot_tools import halfviolin

# Validate inputs
if not contrasts:
raise ValueError("The `contrasts` list cannot be empty.")

if contrast_labels is not None and len(contrast_labels) != len(contrasts):
raise ValueError("`contrast_labels` must match the number of `contrasts` if provided.")

# Load plot data
contrast_plot_data = load_plot_data(contrasts, effect_size, analysis_type)
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)

# Extract data for plotting
bootstraps, differences, bcalows, bcahighs = extract_plot_data(
contrast_plot_data, contrast_labels
contrast_plot_data, contrast_type
)

# Infer the figsize based on the number of contrasts
# Adjust figure size based on orientation
all_groups_count = len(contrasts)
each_group_width_inches = 2.5 # Adjust as needed for width
base_height_inches = 4 # Base height, adjust as needed
height_inches = base_height_inches
width_inches = each_group_width_inches * all_groups_count
fig_size = (width_inches, height_inches)
if horizontal:
fig_size = (4, 2.5 * all_groups_count)
else:
fig_size = (2.5 * all_groups_count, 4)

# Create figure and axes if not provided
if ax is None:
fig, ax = plt.subplots(figsize=fig_size)
else:
fig = ax.figure

# Zero line
ax.plot([0, len(contrasts) + 1], [0, 0], "k", linewidth=zero_line_width)

# Violin plots with customizable colors
# Adjust violin plot orientation based on the 'horizontal' argument
violin_kwargs = violin_kwargs or {
"widths": 0.5,
"vert": True,
"showextrema": False,
"showmedians": False,
}
violin_kwargs["vert"] = not horizontal
v = ax.violinplot(bootstraps, **violin_kwargs)
halfviolin(v, alpha=alpha_violin_plot) # Apply halfviolin from dabest

# Adjust the halfviolin function call based on 'horizontal'
if horizontal:
half = "top"
else:
half = "right" # Assuming "right" is the default or another appropriate value

# Assuming halfviolin has been updated to accept a 'half' parameter
halfviolin(v, alpha=alpha_violin_plot, half=half)

# Handle the custom color palette
if custom_palette:
if isinstance(custom_palette, dict):
Expand All @@ -176,30 +189,32 @@ def forest_plot(
patch.set_facecolor(color)
patch.set_alpha(alpha_violin_plot)

# Effect size dot and confidence interval
# Flipping the axes for plotting based on 'horizontal'
for k in range(1, len(contrasts) + 1):
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)

# Custom settings
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(
xticklabels or range(1, len(contrasts) + 1),
rotation=rotation_for_xlabels,
fontsize=fontsize,
)
ax.set_xlim([0, len(contrasts) + 1])
ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_title(title, fontsize=fontsize)
ylim = (min(bcalows) - 0.25, max(bcahighs) + 0.25)
ax.set_ylim(ylim)
if horizontal:
ax.plot(differences[k - 1], k, "k.", markersize=marker_size) # Flipped axes
ax.plot([bcalows[k - 1], bcahighs[k - 1]], [k, k], "k", linewidth=ci_line_width) # Flipped axes
else:
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)

# Adjusting labels, ticks, and limits based on 'horizontal'
if horizontal:
ax.set_yticks(range(1, len(contrasts) + 1))
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_xlabel(ylabel, fontsize=fontsize)
else:
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)

# Remove spines if requested
# Setting the title and adjusting spines as before
ax.set_title(title, fontsize=fontsize)
if remove_spines:
for spine in ax.spines.values():
spine.set_visible(False)

# Additional customization
# Apply additional customizations if provided
if additional_plotting_kwargs:
ax.set(**additional_plotting_kwargs)

Expand Down
Loading
Loading