diff --git a/.gitignore b/.gitignore index 5f112e28..99237c7e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ _proc/ .gitattributes .last_checked .gitconfig +.cursorignore *.bak *.log *~ diff --git a/dabest/_api.py b/dabest/_api.py index 7c8d0eac..a6399385 100644 --- a/dabest/_api.py +++ b/dabest/_api.py @@ -1,3 +1,5 @@ +"""Loading data and relevant groups""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/load.ipynb. # %% auto 0 diff --git a/dabest/_dabest_object.py b/dabest/_dabest_object.py index 3f618a2a..ec917b03 100644 --- a/dabest/_dabest_object.py +++ b/dabest/_dabest_object.py @@ -1,3 +1,5 @@ +"""Main class for estimating statistics and generating plots.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/dabest_object.ipynb. # %% auto 0 diff --git a/dabest/_delta_objects.py b/dabest/_delta_objects.py index 30c44895..1827c1b2 100644 --- a/dabest/_delta_objects.py +++ b/dabest/_delta_objects.py @@ -1,3 +1,5 @@ +"""Auxiliary delta classes for estimating statistics and generating plots.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/delta_objects.ipynb. # %% auto 0 diff --git a/dabest/_effsize_objects.py b/dabest/_effsize_objects.py index f8bf3846..101562ed 100644 --- a/dabest/_effsize_objects.py +++ b/dabest/_effsize_objects.py @@ -1,3 +1,5 @@ +"""The auxiliary classes involved in the computations of bootstrapped effect sizes.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/effsize_objects.ipynb. # %% auto 0 @@ -971,6 +973,7 @@ def plot( contrast_ylim=None, delta2_ylim=None, swarm_side=None, + empty_circle=False, custom_palette=None, swarm_desat=0.5, halfviolin_desat=1, @@ -994,10 +997,6 @@ def plot( fig_size=None, dpi=100, ax=None, - contrast_show_es=False, - es_sf=2, - es_fontsize=10, - contrast_show_deltas=True, gridkey_rows=None, gridkey_merge_pairs=False, gridkey_show_Ns=True, @@ -1017,6 +1016,17 @@ def plot( fontsize_contrastxlabel=12, fontsize_contrastylabel=12, fontsize_delta2label=12, + #### Contrast bars and delta text and delta dots WIP #### + contrast_bars=True, + swarm_bars=True, + contrast_bars_kwargs=None, + swarm_bars_kwargs=None, + summary_bars=None, + summary_bars_kwargs=None, + delta_text=True, + delta_text_kwargs=None, + delta_dot=True, + delta_dot_kwargs=None, ): """ Creates an estimation plot for the effect size of interest. @@ -1064,6 +1074,12 @@ def plot( https://seaborn.pydata.org/generated/seaborn.cubehelix_palette.html The named colors of matplotlib can be found here: https://matplotlib.org/examples/color/named_colors.html + swarm_side: string, default None + The side on which points are swarmed for swarmplots ("center", "left", or "right"). + empty_circle: boolean, default False + Boolean value determining if empty circles will be used for plotting of + swarmplot for control groups. Color of each individual swarm is also now + dependent on the comparison group. swarm_desat : float, default 1 Decreases the saturation of the colors in the swarmplot by the desired proportion. Uses `seaborn.desaturate()` to acheive this. @@ -1159,7 +1175,39 @@ def plot( Font size for the contrast axes ylabel. fontsize_delta2label : float, default 12 Font size for the delta-delta axes ylabel. - + + + contrast_bars : boolean, default True + Whether or not to display the contrast bars. + swarm_bars : boolean, default True + Whether or not to display the swarm bars. + contrast_bars_kwargs : dict, default None + Pass relevant keyword arguments to the contrast bars. Pass any keyword arguments accepted by + matplotlib.patches.Rectangle here, as a string. If None, the following keywords are passed: + {"color": None, "alpha": 0.3} + swarm_bars_kwargs : dict, default None + Pass relevant keyword arguments to the swarm bars. Pass any keyword arguments accepted by + matplotlib.patches.Rectangle here, as a string. If None, the following keywords are passed: + {"color": None, "alpha": 0.3} + + summary_bars : list, default None + Pass a list of indices of the contrast objects to have summary bars displayed on the plot. + For example, [0,1] will show summary bars for the first two contrast objects. + summary_bars_kwargs: dict, default None + If None, the following keywords are passed: {"color": None, "alpha": 0.15} + delta_text : boolean, default True + Whether or not to display the text deltas. + delta_text_kwargs : dict, default None + Pass relevant keyword arguments to the delta text. Pass any keyword arguments accepted by + matplotlib.text.Text here, as a string. If None, the following keywords are passed: + {"color": None, "alpha": 1, "fontsize": 10, "ha": 'center', "va": 'center', "rotation": 0, + "x_location": 'right', "x_coordinates": None, "y_coordinates": None} + Use "x_coordinates" and "y_coordinates" if you would like to specify the text locations manually. + delta_dot : boolean, default True + Whether or not to display the delta dots on paired or repeated measure plots. + delta_dot_kwargs : dict, default None + Pass relevant keyword arguments. If None, the following keywords are passed: + {"marker": "^", "alpha": 0.5, "zorder": 2, "size": 3, "side": "right"} Returns ------- @@ -1180,7 +1228,7 @@ def plot( if hasattr(self, "results") is False: self.__pre_calc() - if self.__delta2: + if self.__delta2 and not empty_circle: color_col = self.__x2 # if self.__proportional: diff --git a/dabest/_modidx.py b/dabest/_modidx.py index 14bfa3da..24356d58 100644 --- a/dabest/_modidx.py +++ b/dabest/_modidx.py @@ -65,11 +65,30 @@ 'dabest/forest_plot.py'), 'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'), 'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')}, - 'dabest.misc_tools': { 'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'), + 'dabest.misc_tools': { 'dabest.misc_tools.Cumming_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#cumming_plot_aesthetic_adjustments', + 'dabest/misc_tools.py'), + 'dabest.misc_tools.Gardner_Altman_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#gardner_altman_plot_aesthetic_adjustments', + 'dabest/misc_tools.py'), + 'dabest.misc_tools.General_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#general_plot_aesthetic_adjustments', + 'dabest/misc_tools.py'), + 'dabest.misc_tools.add_counts_to_ticks': ( 'API/misc_tools.html#add_counts_to_ticks', + 'dabest/misc_tools.py'), + 'dabest.misc_tools.extract_contrast_plotting_ticks': ( 'API/misc_tools.html#extract_contrast_plotting_ticks', + 'dabest/misc_tools.py'), + 'dabest.misc_tools.get_color_palette': ('API/misc_tools.html#get_color_palette', 'dabest/misc_tools.py'), + 'dabest.misc_tools.get_kwargs': ('API/misc_tools.html#get_kwargs', 'dabest/misc_tools.py'), + 'dabest.misc_tools.get_params': ('API/misc_tools.html#get_params', 'dabest/misc_tools.py'), + 'dabest.misc_tools.get_plot_groups': ('API/misc_tools.html#get_plot_groups', 'dabest/misc_tools.py'), + 'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'), + 'dabest.misc_tools.initialize_fig': ('API/misc_tools.html#initialize_fig', 'dabest/misc_tools.py'), 'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'), 'dabest.misc_tools.print_greeting': ('API/misc_tools.html#print_greeting', 'dabest/misc_tools.py'), + 'dabest.misc_tools.set_xaxis_ticks_and_lims': ( 'API/misc_tools.html#set_xaxis_ticks_and_lims', + 'dabest/misc_tools.py'), + 'dabest.misc_tools.show_legend': ('API/misc_tools.html#show_legend', 'dabest/misc_tools.py'), 'dabest.misc_tools.unpack_and_add': ('API/misc_tools.html#unpack_and_add', 'dabest/misc_tools.py')}, - 'dabest.plot_tools': { 'dabest.plot_tools.SwarmPlot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'), + 'dabest.plot_tools': { 'dabest.plot_tools.DeltaDotsPlotter': ('API/plot_tools.html#deltadotsplotter', 'dabest/plot_tools.py'), + 'dabest.plot_tools.SwarmPlot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'), 'dabest.plot_tools.SwarmPlot.__init__': ( 'API/plot_tools.html#swarmplot.__init__', 'dabest/plot_tools.py'), 'dabest.plot_tools.SwarmPlot._adjust_gutter_points': ( 'API/plot_tools.html#swarmplot._adjust_gutter_points', @@ -82,14 +101,30 @@ 'dabest/plot_tools.py'), 'dabest.plot_tools.SwarmPlot._swarm': ('API/plot_tools.html#swarmplot._swarm', 'dabest/plot_tools.py'), 'dabest.plot_tools.SwarmPlot.plot': ('API/plot_tools.html#swarmplot.plot', 'dabest/plot_tools.py'), + 'dabest.plot_tools.barplotter': ('API/plot_tools.html#barplotter', 'dabest/plot_tools.py'), 'dabest.plot_tools.check_data_matches_labels': ( 'API/plot_tools.html#check_data_matches_labels', 'dabest/plot_tools.py'), + 'dabest.plot_tools.contrast_bars_plotter': ( 'API/plot_tools.html#contrast_bars_plotter', + 'dabest/plot_tools.py'), + 'dabest.plot_tools.delta_text_plotter': ( 'API/plot_tools.html#delta_text_plotter', + 'dabest/plot_tools.py'), + 'dabest.plot_tools.effect_size_curve_plotter': ( 'API/plot_tools.html#effect_size_curve_plotter', + 'dabest/plot_tools.py'), 'dabest.plot_tools.error_bar': ('API/plot_tools.html#error_bar', 'dabest/plot_tools.py'), 'dabest.plot_tools.get_swarm_spans': ('API/plot_tools.html#get_swarm_spans', 'dabest/plot_tools.py'), + 'dabest.plot_tools.grid_key_WIP': ('API/plot_tools.html#grid_key_wip', 'dabest/plot_tools.py'), 'dabest.plot_tools.halfviolin': ('API/plot_tools.html#halfviolin', 'dabest/plot_tools.py'), 'dabest.plot_tools.normalize_dict': ('API/plot_tools.html#normalize_dict', 'dabest/plot_tools.py'), + 'dabest.plot_tools.plot_minimeta_or_deltadelta_violins': ( 'API/plot_tools.html#plot_minimeta_or_deltadelta_violins', + 'dabest/plot_tools.py'), 'dabest.plot_tools.sankeydiag': ('API/plot_tools.html#sankeydiag', 'dabest/plot_tools.py'), 'dabest.plot_tools.single_sankey': ('API/plot_tools.html#single_sankey', 'dabest/plot_tools.py'), + 'dabest.plot_tools.slopegraph_plotter': ( 'API/plot_tools.html#slopegraph_plotter', + 'dabest/plot_tools.py'), + 'dabest.plot_tools.summary_bars_plotter': ( 'API/plot_tools.html#summary_bars_plotter', + 'dabest/plot_tools.py'), + 'dabest.plot_tools.swarm_bars_plotter': ( 'API/plot_tools.html#swarm_bars_plotter', + 'dabest/plot_tools.py'), 'dabest.plot_tools.swarmplot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'), 'dabest.plot_tools.width_determine': ('API/plot_tools.html#width_determine', 'dabest/plot_tools.py')}, 'dabest.plotter': {'dabest.plotter.effectsize_df_plotter': ('API/plotter.html#effectsize_df_plotter', 'dabest/plotter.py')}}} diff --git a/dabest/_stats_tools/confint_1group.py b/dabest/_stats_tools/confint_1group.py index a9b0beb1..744a7142 100644 --- a/dabest/_stats_tools/confint_1group.py +++ b/dabest/_stats_tools/confint_1group.py @@ -1,3 +1,5 @@ +"""A range of functions to compute bootstraps for a single sample.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/API/confint_1group.ipynb. # %% auto 0 diff --git a/dabest/_stats_tools/confint_2group_diff.py b/dabest/_stats_tools/confint_2group_diff.py index 3b07eb96..c599e178 100644 --- a/dabest/_stats_tools/confint_2group_diff.py +++ b/dabest/_stats_tools/confint_2group_diff.py @@ -1,3 +1,5 @@ +"""A range of functions to compute bootstraps for the mean difference""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/API/confint_2group_diff.ipynb. # %% auto 0 diff --git a/dabest/_stats_tools/effsize.py b/dabest/_stats_tools/effsize.py index 32f965b1..f5a0d4fc 100644 --- a/dabest/_stats_tools/effsize.py +++ b/dabest/_stats_tools/effsize.py @@ -1,3 +1,5 @@ +"""A range of functions to compute various effect sizes.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/API/effsize.ipynb. # %% ../../nbs/API/effsize.ipynb 4 diff --git a/dabest/forest_plot.py b/dabest/forest_plot.py index 7d29464f..583ece0c 100644 --- a/dabest/forest_plot.py +++ b/dabest/forest_plot.py @@ -1,3 +1,5 @@ +"""Creating forest plots from contrast objects.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb. # %% auto 0 diff --git a/dabest/misc_tools.py b/dabest/misc_tools.py index 7c5b2020..8c7d0e96 100644 --- a/dabest/misc_tools.py +++ b/dabest/misc_tools.py @@ -1,11 +1,21 @@ +"""Convenience functions that don't directly deal with plotting or bootstrap computations are placed here.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/misc_tools.ipynb. # %% auto 0 -__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname'] +__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_params', 'get_kwargs', 'get_color_palette', + 'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks', 'extract_contrast_plotting_ticks', + 'set_xaxis_ticks_and_lims', 'show_legend', 'Gardner_Altman_Plot_Aesthetic_Adjustments', + 'Cumming_Plot_Aesthetic_Adjustments', 'General_Plot_Aesthetic_Adjustments'] # %% ../nbs/API/misc_tools.ipynb 4 import datetime as dt +import numpy as np from numpy import repeat +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +import matplotlib # %% ../nbs/API/misc_tools.ipynb 5 def merge_two_dicts( @@ -68,3 +78,1022 @@ def get_varname(obj): if len(matching_vars) > 0: return matching_vars[0] return "" + +def get_params(effectsize_df, plot_kwargs): + """ + Parameters + ---------- + effectsize_df : object (Dataframe) + A `dabest` EffectSizeDataFrame object. + plot_kwargs : dict + Kwargs passed to the plot function. + """ + dabest_obj = effectsize_df.dabest_obj + plot_data = effectsize_df._plot_data + xvar = effectsize_df.xvar + yvar = effectsize_df.yvar + is_paired = effectsize_df.is_paired + delta2 = effectsize_df.delta2 + mini_meta = effectsize_df.mini_meta + effect_size = effectsize_df.effect_size + proportional = effectsize_df.proportional + all_plot_groups = dabest_obj._all_plot_groups + idx = dabest_obj.idx + + if effect_size not in ["mean_diff", "delta_g"] or not delta2: + show_delta2 = False + else: + show_delta2 = plot_kwargs["show_delta2"] + + if effect_size != "mean_diff" or not mini_meta: + show_mini_meta = False + else: + show_mini_meta = plot_kwargs["show_mini_meta"] + + if show_delta2 and show_mini_meta: raise ValueError("`show_delta2` and `show_mini_meta` cannot be True at the same time.") + + # Disable Gardner-Altman plotting if any of the idxs comprise of more than + # two groups or if it is a delta-delta plot. + float_contrast = plot_kwargs["float_contrast"] + effect_size_type = effectsize_df.effect_size + if len(idx) > 1 or len(idx[0]) > 2: + float_contrast = False + + if effect_size_type in ["cliffs_delta"]: + float_contrast = False + + if show_delta2 or show_mini_meta: + float_contrast = False + + if not is_paired: + show_pairs = False + else: + show_pairs = plot_kwargs["show_pairs"] + + # Group summaries + group_summaries = plot_kwargs["group_summaries"] + if group_summaries is None: + group_summaries = "mean_sd" + + # Error bar color + err_color = plot_kwargs["err_color"] + if err_color is None: + err_color = "black" + + return (dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, proportional, all_plot_groups, idx, + show_delta2, show_mini_meta, float_contrast, show_pairs, effect_size_type, group_summaries, err_color) + +def get_kwargs(plot_kwargs, ytick_color): + """ + Parameters + ---------- + plot_kwargs : dict + Kwargs passed to the plot function. + ytick_color : str + Color of the yticks. + """ + from .misc_tools import merge_two_dicts + + # Swarmplot kwargs + default_swarmplot_kwargs = {"size": plot_kwargs["raw_marker_size"]} + if plot_kwargs["swarmplot_kwargs"] is None: + swarmplot_kwargs = default_swarmplot_kwargs + else: + swarmplot_kwargs = merge_two_dicts( + default_swarmplot_kwargs, plot_kwargs["swarmplot_kwargs"] + ) + + # Barplot kwargs + default_barplot_kwargs = {"estimator": np.mean, "errorbar": plot_kwargs["ci"]} + if plot_kwargs["barplot_kwargs"] is None: + barplot_kwargs = default_barplot_kwargs + else: + barplot_kwargs = merge_two_dicts( + default_barplot_kwargs, plot_kwargs["barplot_kwargs"] + ) + + # Sankey Diagram kwargs + default_sankey_kwargs = { + "width": 0.4, + "align": "center", + "sankey": True, + "flow": True, + "alpha": 0.4, + "rightColor": False, + "bar_width": 0.2, + } + if plot_kwargs["sankey_kwargs"] is None: + sankey_kwargs = default_sankey_kwargs + else: + sankey_kwargs = merge_two_dicts( + default_sankey_kwargs, plot_kwargs["sankey_kwargs"] + ) + + # Violinplot kwargs. + default_violinplot_kwargs = { + "widths": 0.5, + "vert": True, + "showextrema": False, + "showmedians": False, + } + if plot_kwargs["violinplot_kwargs"] is None: + violinplot_kwargs = default_violinplot_kwargs + else: + violinplot_kwargs = merge_two_dicts( + default_violinplot_kwargs, plot_kwargs["violinplot_kwargs"] + ) + + # Slopegraph kwargs. + default_slopegraph_kwargs = {"linewidth": 1, "alpha": 0.5} + if plot_kwargs["slopegraph_kwargs"] is None: + slopegraph_kwargs = default_slopegraph_kwargs + else: + slopegraph_kwargs = merge_two_dicts( + default_slopegraph_kwargs, plot_kwargs["slopegraph_kwargs"] + ) + + # Zero reference-line kwargs. + default_reflines_kwargs = { + "linestyle": "solid", + "linewidth": 0.75, + "zorder": 2, + "color": ytick_color, + } + if plot_kwargs["reflines_kwargs"] is None: + reflines_kwargs = default_reflines_kwargs + else: + reflines_kwargs = merge_two_dicts( + default_reflines_kwargs, plot_kwargs["reflines_kwargs"] + ) + + # Legend kwargs. + default_legend_kwargs = {"loc": "upper left", "frameon": False} + if plot_kwargs["legend_kwargs"] is None: + legend_kwargs = default_legend_kwargs + else: + legend_kwargs = merge_two_dicts( + default_legend_kwargs, plot_kwargs["legend_kwargs"] + ) + + # Group summaries kwargs. + gs_default = {"mean_sd", "median_quartiles", None} + if plot_kwargs["group_summaries"] not in gs_default: + raise ValueError( + "group_summaries must be one of" " these: {}.".format(gs_default) + ) + + default_group_summary_kwargs = {"zorder": 3, "lw": 2, "alpha": 1} + if plot_kwargs["group_summary_kwargs"] is None: + group_summary_kwargs = default_group_summary_kwargs + else: + group_summary_kwargs = merge_two_dicts( + default_group_summary_kwargs, plot_kwargs["group_summary_kwargs"] + ) + + # Redraw axes kwargs. + redraw_axes_kwargs = { + "colors": ytick_color, + "facecolors": ytick_color, + "lw": 1, + "zorder": 10, + "clip_on": False, + } + + # Delta dots kwargs. + default_delta_dot_kwargs = {"marker": "^", "alpha": 0.5, "zorder": 2, "size": 3, "side": "right"} + if plot_kwargs["delta_dot_kwargs"] is None: + delta_dot_kwargs = default_delta_dot_kwargs + else: + delta_dot_kwargs = merge_two_dicts(default_delta_dot_kwargs, plot_kwargs["delta_dot_kwargs"]) + + # Delta text kwargs. + default_delta_text_kwargs = {"color": None, "alpha": 1, "fontsize": 10, "ha": 'center', "va": 'center', "rotation": 0, "x_location": 'right', "x_coordinates": None, "y_coordinates": None} + if plot_kwargs["delta_text_kwargs"] is None: + delta_text_kwargs = default_delta_text_kwargs + else: + delta_text_kwargs = merge_two_dicts(default_delta_text_kwargs, plot_kwargs["delta_text_kwargs"]) + + # Summary bars kwargs. + default_summary_bars_kwargs = {"color": None, "alpha": 0.15} + if plot_kwargs["summary_bars_kwargs"] is None: + summary_bars_kwargs = default_summary_bars_kwargs + else: + summary_bars_kwargs = merge_two_dicts(default_summary_bars_kwargs, plot_kwargs["summary_bars_kwargs"]) + + # Swarm bars kwargs. + default_swarm_bars_kwargs = {"color": None, "alpha": 0.3} + if plot_kwargs["swarm_bars_kwargs"] is None: + swarm_bars_kwargs = default_swarm_bars_kwargs + else: + swarm_bars_kwargs = merge_two_dicts(default_swarm_bars_kwargs, plot_kwargs["swarm_bars_kwargs"]) + + # Contrast bars kwargs. + default_contrast_bars_kwargs = {"color": None, "alpha": 0.3} + if plot_kwargs["contrast_bars_kwargs"] is None: + contrast_bars_kwargs = default_contrast_bars_kwargs + else: + contrast_bars_kwargs = merge_two_dicts(default_contrast_bars_kwargs, plot_kwargs["contrast_bars_kwargs"]) + + return (swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs, slopegraph_kwargs, + reflines_kwargs, legend_kwargs, group_summary_kwargs, redraw_axes_kwargs, delta_dot_kwargs, + delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs) + + +def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx): + + # Create color palette that will be shared across subplots. + color_col = plot_kwargs["color_col"] + if color_col is None: + color_groups = pd.unique(plot_data[xvar]) + bootstraps_color_by_group = True + else: + if color_col not in plot_data.columns: + raise KeyError("``{}`` is not a column in the data.".format(color_col)) + color_groups = pd.unique(plot_data[color_col]) + bootstraps_color_by_group = False + if show_pairs: + bootstraps_color_by_group = False + # Handle the color palette. + filled = True + empty_circle = plot_kwargs["empty_circle"] + color_by_subgroups = ( + True if empty_circle else False + ) # boolean flag to determine if colour is being grouped by subgroup or the default + if empty_circle: + # Handling color_by_subgroups + # For now, color_by_subgroups can only be True for multi-2-group and 2-group comparison + if isinstance(idx[0], str): + if len(idx) > 2: + color_by_subgroups = False + else: + for group_i in idx: + if len(group_i) > 2: + color_by_subgroups = False + + # filled is now a list, which determines the which group in idx has their dots filled for the swarmplot + filled = [] + for i in range(len(idx)): + filled.append(False) + filled.extend([True] * (len(idx[i]) - 1)) + + names = color_groups if not color_by_subgroups else idx + n_groups = len(color_groups) + custom_pal = plot_kwargs["custom_palette"] + swarm_desat = plot_kwargs["swarm_desat"] + bar_desat = plot_kwargs["bar_desat"] + contrast_desat = plot_kwargs["halfviolin_desat"] + + if custom_pal is None: + unsat_colors = sns.color_palette(n_colors=n_groups) + if empty_circle and not color_by_subgroups: + unsat_colors = [sns.color_palette("gray")[3]] + unsat_colors + else: + if isinstance(custom_pal, dict): + groups_in_palette = { + k: v for k, v in custom_pal.items() if k in color_groups + } + + names = groups_in_palette.keys() + unsat_colors = groups_in_palette.values() + + elif isinstance(custom_pal, list): + unsat_colors = custom_pal[0:n_groups] + + elif isinstance(custom_pal, str): + # check it is in the list of matplotlib palettes. + if custom_pal in plt.colormaps(): + unsat_colors = sns.color_palette(custom_pal, n_groups) + else: + err1 = "The specified `custom_palette` {}".format(custom_pal) + err2 = " is not a matplotlib palette. Please check." + raise ValueError(err1 + err2) + + if custom_pal is None and color_col is None: + swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors] + contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors] + bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors] + if color_by_subgroups: + plot_palette_raw = dict() + plot_palette_contrast = dict() + # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots + plot_palette_bar = None + for i in range(len(idx)): + for names_i in idx[i]: + plot_palette_raw[names_i] = swarm_colors[i] + plot_palette_contrast[names_i] = contrast_colors[i] + else: + plot_palette_raw = dict(zip(names.categories, swarm_colors)) + plot_palette_contrast = dict(zip(names.categories, contrast_colors)) + plot_palette_bar = dict(zip(names.categories, bar_color)) + + # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors + # default color palette will be set to "hls" + plot_palette_sankey = None + + else: + swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors] + contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors] + bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors] + if color_by_subgroups: + plot_palette_raw = dict() + plot_palette_contrast = dict() + # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots + plot_palette_bar = None + for i in range(len(idx)): + for names_i in idx[i]: + plot_palette_raw[names_i] = swarm_colors[i] + plot_palette_contrast[names_i] = contrast_colors[i] + else: + plot_palette_raw = dict(zip(names, swarm_colors)) + plot_palette_contrast = dict(zip(names, contrast_colors)) + plot_palette_bar = dict(zip(names, bar_color)) + + plot_palette_sankey = custom_pal + + return (color_col, bootstraps_color_by_group, n_groups, filled, swarm_colors, plot_palette_raw, + bar_color, plot_palette_bar, plot_palette_contrast, plot_palette_sankey) + +def initialize_fig(plot_kwargs, dabest_obj, show_delta2, show_mini_meta, is_paired, show_pairs, proportional, + float_contrast): + # Params + fig_size = plot_kwargs["fig_size"] + face_color = plot_kwargs["face_color"] + if plot_kwargs["face_color"] is None: + face_color = "white" + + if fig_size is None: + all_groups_count = np.sum([len(i) for i in dabest_obj.idx]) + # Increase the width for delta-delta graph + if show_delta2 or show_mini_meta: + all_groups_count += 2 + if is_paired and show_pairs and proportional is False: + frac = 0.8 + else: + frac = 1 + if float_contrast: + height_inches = 4 + each_group_width_inches = 2.5 * frac + else: + height_inches = 6 + each_group_width_inches = 1.5 * frac + + width_inches = each_group_width_inches * all_groups_count + fig_size = (width_inches, height_inches) + + init_fig_kwargs = dict(figsize=fig_size, dpi=plot_kwargs["dpi"], tight_layout=True) + width_ratios_ga = [2.5, 1] + + h_space_cummings = 0.3 if plot_kwargs["gridkey_rows"] == None else 0.1 ##### GRIDKEY WIP addition + + if plot_kwargs["ax"] is not None: + # New in v0.2.6. + # Use inset axes to create the estimation plot inside a single axes. + # Author: Adam L Nekimken. (PR #73) + rawdata_axes = plot_kwargs["ax"] + ax_position = rawdata_axes.get_position() # [[x0, y0], [x1, y1]] + + fig = rawdata_axes.get_figure() + fig.patch.set_facecolor(face_color) + + if float_contrast: + axins = rawdata_axes.inset_axes( + [1, 0, width_ratios_ga[1] / width_ratios_ga[0], 1] + ) + rawdata_axes.set_position( # [l, b, w, h] + [ + ax_position.x0, + ax_position.y0, + (ax_position.x1 - ax_position.x0) + * (width_ratios_ga[0] / sum(width_ratios_ga)), + (ax_position.y1 - ax_position.y0), + ] + ) + + contrast_axes = axins + else: + axins = rawdata_axes.inset_axes([0, -1 - h_space_cummings, 1, 1]) + plot_height = (ax_position.y1 - ax_position.y0) / (2 + h_space_cummings) + rawdata_axes.set_position( + [ + ax_position.x0, + ax_position.y0 + (1 + h_space_cummings) * plot_height, + (ax_position.x1 - ax_position.x0), + plot_height, + ] + ) + + contrast_axes = axins + rawdata_axes.contrast_axes = axins + + else: + # Here, we hardcode some figure parameters. + if float_contrast: + fig, axx = plt.subplots( + ncols=2, + gridspec_kw={"width_ratios": width_ratios_ga, "wspace": 0}, + **init_fig_kwargs + ) + fig.patch.set_facecolor(face_color) + + else: + fig, axx = plt.subplots( + nrows=2, gridspec_kw={"hspace": h_space_cummings}, **init_fig_kwargs + ) + fig.patch.set_facecolor(face_color) + + # Title + title = plot_kwargs["title"] + fontsize_title = plot_kwargs["fontsize_title"] + if title is not None: + fig.suptitle(title, fontsize=fontsize_title) + rawdata_axes = axx[0] + contrast_axes = axx[1] + rawdata_axes.set_frame_on(False) + contrast_axes.set_frame_on(False) + + swarm_ylim = plot_kwargs["swarm_ylim"] + if swarm_ylim is not None: + rawdata_axes.set_ylim(swarm_ylim) + + return fig, rawdata_axes, contrast_axes, swarm_ylim + +def get_plot_groups(is_paired, idx, proportional, all_plot_groups): + + if is_paired == "baseline": + idx_pairs = [ + (control, test) + for i in idx + for control, test in zip([i[0]] * (len(i) - 1), i[1:]) + ] + temp_idx = idx if not proportional else idx_pairs + else: + idx_pairs = [ + (control, test) for i in idx for control, test in zip(i[:-1], i[1:]) + ] + temp_idx = idx if not proportional else idx_pairs + + # Determine temp_all_plot_groups based on proportional condition + plot_groups = [item for i in temp_idx for item in i] + temp_all_plot_groups = all_plot_groups if not proportional else plot_groups + + return temp_idx, temp_all_plot_groups + + +def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs): + # Add the counts to the rawdata axes xticks. + counts = plot_data.groupby(xvar).count()[yvar] + ticks_with_counts = [] + ticks_loc = rawdata_axes.get_xticks() + rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc)) + def lookup_value(text, counts): + try: + return str(counts.loc[text]) + except KeyError: + try: + numeric_key = pd.to_numeric(text, errors='coerce') + if pd.notnull(numeric_key): + return str(counts.loc[numeric_key]) + else: + raise ValueError + except (ValueError, KeyError): + print(f"Key '{text}' not found in counts.") + return "N/A" + for xticklab in rawdata_axes.xaxis.get_ticklabels(): + t = xticklab.get_text() + # Extract the text after the last newline, if present + if t.rfind("\n") != -1: + te = t[t.rfind("\n") + len("\n"):] + value = lookup_value(te, counts) + te = t + else: + te = t + value = lookup_value(te, counts) + + # Append the modified tick label with the count to the list + ticks_with_counts.append(f"{te}\nN = {value}") + + + if plot_kwargs["fontsize_rawxlabel"] is not None: + fontsize_rawxlabel = plot_kwargs["fontsize_rawxlabel"] + rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel) + + +def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_groups, idx, sankey_control_group): + + # Take note of where the `control` groups are. + ticks_to_skip_contrast = None + ticks_to_start_twocol_sankey = None + if is_paired == "baseline" and show_pairs: + if two_col_sankey: + ticks_to_skip = [] + ticks_to_plot = np.arange(0, len(plot_groups) / 2).tolist() + ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist() + ticks_to_start_twocol_sankey.pop() + ticks_to_start_twocol_sankey.insert(0, 0) + else: + # ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist() + # ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist() + ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist() + ticks_to_skip.insert(0, 0) + # Then obtain the ticks where we have to plot the effect sizes. + ticks_to_plot = [ + t for t in range(0, len(plot_groups)) if t not in ticks_to_skip + ] + ticks_to_skip_contrast = np.cumsum([(len(t)) for t in idx])[:-1].tolist() + ticks_to_skip_contrast.insert(0, 0) + else: + if two_col_sankey: + ticks_to_skip = [len(sankey_control_group)] + # Then obtain the ticks where we have to plot the effect sizes. + ticks_to_plot = [ + t for t in range(0, len(plot_groups)) if t not in ticks_to_skip + ] + ticks_to_skip = [] + ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist() + ticks_to_start_twocol_sankey.pop() + ticks_to_start_twocol_sankey.insert(0, 0) + else: + ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist() + ticks_to_skip.insert(0, 0) + # Then obtain the ticks where we have to plot the effect sizes. + ticks_to_plot = [ + t for t in range(0, len(plot_groups)) if t not in ticks_to_skip + ] + + return ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast, ticks_to_start_twocol_sankey + +def set_xaxis_ticks_and_lims(show_delta2, show_mini_meta, rawdata_axes, contrast_axes, show_pairs, float_contrast, + ticks_to_skip, contrast_xtick_labels, plot_kwargs): + + if show_delta2 is False and show_mini_meta is False: + contrast_axes.set_xticks(rawdata_axes.get_xticks()) + else: + temp = rawdata_axes.get_xticks() + temp = np.append(temp, [max(temp) + 1, max(temp) + 2]) + contrast_axes.set_xticks(temp) + + if show_pairs: + max_x = contrast_axes.get_xlim()[1] + rawdata_axes.set_xlim(-0.375, max_x) + + if float_contrast: + contrast_axes.set_xlim(0.5, 1.5) + elif show_delta2 or show_mini_meta: + # Increase the xlim of raw data by 2 + temp = rawdata_axes.get_xlim() + if show_pairs: + rawdata_axes.set_xlim(temp[0], temp[1] + 0.25) + else: + rawdata_axes.set_xlim(temp[0], temp[1] + 2) + contrast_axes.set_xlim(rawdata_axes.get_xlim()) + else: + contrast_axes.set_xlim(rawdata_axes.get_xlim()) + + # Properly label the contrast ticks. + for t in ticks_to_skip: + contrast_xtick_labels.insert(t, "") + + if plot_kwargs["fontsize_contrastxlabel"] is not None: + fontsize_contrastxlabel = plot_kwargs["fontsize_contrastxlabel"] + + contrast_axes.set_xticklabels( + contrast_xtick_labels, fontsize=fontsize_contrastxlabel + ) + + +def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, float_contrast, show_pairs, legend_kwargs): + + legend_labels_unique = np.unique(legend_labels) + unique_idx = np.unique(legend_labels, return_index=True)[1] + legend_handles_unique = ( + pd.Series(legend_handles, dtype="object").loc[unique_idx] + ).tolist() + + if len(legend_handles_unique) > 0: + if float_contrast: + axes_with_legend = contrast_axes + if show_pairs: + bta = (2.00, 1.02) + else: + bta = (1.5, 1.02) + else: + axes_with_legend = rawdata_axes + if show_pairs: + bta = (1.02, 1.0) + else: + bta = (1.0, 1.0) + leg = axes_with_legend.legend( + legend_handles_unique, + legend_labels_unique, + bbox_to_anchor=bta, + **legend_kwargs + ) + if show_pairs: + for line in leg.get_lines(): + line.set_linewidth(3.0) + +def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar, yvar, current_control, current_group, + rawdata_axes, contrast_axes, results, current_effsize, is_paired, one_sankey, + reflines_kwargs, redraw_axes_kwargs, swarm_ylim, og_xlim_raw, og_ylim_raw): + from ._stats_tools.effsize import ( + _compute_standardizers, + _compute_hedges_correction_factor, + ) + # Normalize ylims and despine the floating contrast axes. + # Check that the effect size is within the swarm ylims. + if effect_size_type in ["mean_diff", "cohens_d", "hedges_g", "cohens_h"]: + control_group_summary = ( + plot_data.groupby(xvar) + .mean(numeric_only=True) + .loc[current_control, yvar] + ) + test_group_summary = ( + plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar] + ) + elif effect_size_type == "median_diff": + control_group_summary = ( + plot_data.groupby(xvar).median().loc[current_control, yvar] + ) + test_group_summary = ( + plot_data.groupby(xvar).median().loc[current_group, yvar] + ) + + if swarm_ylim is None: + swarm_ylim = rawdata_axes.get_ylim() + + _, contrast_xlim_max = contrast_axes.get_xlim() + + difference = float(results.difference[0]) + + if effect_size_type in ["mean_diff", "median_diff"]: + # Align 0 of contrast_axes to reference group mean of rawdata_axes. + # If the effect size is positive, shift the contrast axis up. + rawdata_ylims = np.array(rawdata_axes.get_ylim()) + if current_effsize > 0: + rightmin, rightmax = rawdata_ylims - current_effsize + # If the effect size is negative, shift the contrast axis down. + elif current_effsize < 0: + rightmin, rightmax = rawdata_ylims + current_effsize + else: + rightmin, rightmax = rawdata_ylims + + contrast_axes.set_ylim(rightmin, rightmax) + + og_ylim_contrast = rawdata_axes.get_ylim() - np.array(control_group_summary) + + contrast_axes.set_ylim(og_ylim_contrast) + contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max) + + elif effect_size_type in ["cohens_d", "hedges_g", "cohens_h"]: + if is_paired: + which_std = 1 + else: + which_std = 0 + temp_control = plot_data[plot_data[xvar] == current_control][yvar] + temp_test = plot_data[plot_data[xvar] == current_group][yvar] + + stds = _compute_standardizers(temp_control, temp_test) + if is_paired: + pooled_sd = stds[1] + else: + pooled_sd = stds[0] + + if effect_size_type == "hedges_g": + gby_count = plot_data.groupby(xvar).count() + len_control = gby_count.loc[current_control, yvar] + len_test = gby_count.loc[current_group, yvar] + + hg_correction_factor = _compute_hedges_correction_factor( + len_control, len_test + ) + + ylim_scale_factor = pooled_sd / hg_correction_factor + + elif effect_size_type == "cohens_h": + ylim_scale_factor = ( + np.mean(temp_test) - np.mean(temp_control) + ) / difference + + else: + ylim_scale_factor = pooled_sd + + scaled_ylim = ( + (rawdata_axes.get_ylim() - control_group_summary) / ylim_scale_factor + ).tolist() + + contrast_axes.set_ylim(scaled_ylim) + og_ylim_contrast = scaled_ylim + + contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max) + + if one_sankey is None: + # Draw summary lines for control and test groups.. + for jj, axx in enumerate([rawdata_axes, contrast_axes]): + # Draw effect size line. + if jj == 0: + ref = control_group_summary + diff = test_group_summary + effsize_line_start = 1 + + elif jj == 1: + ref = 0 + diff = ref + difference + effsize_line_start = contrast_xlim_max - 1.1 + + xlimlow, xlimhigh = axx.get_xlim() + + # Draw reference line. + axx.hlines( + ref, # y-coordinates + 0, + xlimhigh, # x-coordinates, start and end. + **reflines_kwargs + ) + + # Draw effect size line. + axx.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs) + else: + ref = 0 + diff = ref + difference + effsize_line_start = contrast_xlim_max - 0.9 + xlimlow, xlimhigh = contrast_axes.get_xlim() + # Draw reference line. + contrast_axes.hlines( + ref, # y-coordinates + effsize_line_start, + xlimhigh, # x-coordinates, start and end. + **reflines_kwargs + ) + + # Draw effect size line. + contrast_axes.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs) + rawdata_axes.set_xlim(og_xlim_raw) # to align the axis + # Despine appropriately. + sns.despine(ax=rawdata_axes, bottom=True) + sns.despine(ax=contrast_axes, left=True, right=False) + + # Insert break between the rawdata axes and the contrast axes + # by re-drawing the x-spine. + rawdata_axes.hlines( + og_ylim_raw[0], # yindex + rawdata_axes.get_xlim()[0], + 1.3, # xmin, xmax + **redraw_axes_kwargs + ) + rawdata_axes.set_ylim(og_ylim_raw) + + contrast_axes.hlines( + contrast_axes.get_ylim()[0], + contrast_xlim_max - 0.8, + contrast_xlim_max, + **redraw_axes_kwargs + ) + + +def Cumming_Plot_Aesthetic_Adjustments(plot_kwargs, show_delta2, effect_size_type, contrast_axes, reflines_kwargs, + is_paired, show_pairs, two_col_sankey, idx, ticks_to_start_twocol_sankey, + proportional, ticks_to_skip, temp_idx, rawdata_axes, redraw_axes_kwargs, + ticks_to_skip_contrast): + # Set custom contrast_ylim, if it was specified. + if plot_kwargs["contrast_ylim"] is not None or ( + plot_kwargs["delta2_ylim"] is not None and show_delta2 + ): + if plot_kwargs["contrast_ylim"] is not None: + custom_contrast_ylim = plot_kwargs["contrast_ylim"] + if plot_kwargs["delta2_ylim"] is not None and show_delta2: + custom_delta2_ylim = plot_kwargs["delta2_ylim"] + if custom_contrast_ylim != custom_delta2_ylim: + err1 = "Please check if `contrast_ylim` and `delta2_ylim` are assigned" + err2 = "with same values." + raise ValueError(err1 + err2) + else: + custom_delta2_ylim = plot_kwargs["delta2_ylim"] + custom_contrast_ylim = custom_delta2_ylim + + if len(custom_contrast_ylim) != 2: + err1 = "Please check `contrast_ylim` consists of " + err2 = "exactly two numbers." + raise ValueError(err1 + err2) + + if effect_size_type == "cliffs_delta": + # Ensure the ylims for a cliffs_delta plot never exceed [-1, 1]. + l = plot_kwargs["contrast_ylim"][0] + h = plot_kwargs["contrast_ylim"][1] + low = -1 if l < -1 else l + high = 1 if h > 1 else h + contrast_axes.set_ylim(low, high) + else: + contrast_axes.set_ylim(custom_contrast_ylim) + + + # If 0 lies within the ylim of the contrast axes, + # draw a zero reference line. + contrast_axes_ylim = contrast_axes.get_ylim() + if contrast_axes_ylim[0] < contrast_axes_ylim[1]: + contrast_ylim_low, contrast_ylim_high = contrast_axes_ylim + else: + contrast_ylim_high, contrast_ylim_low = contrast_axes_ylim + if contrast_ylim_low < 0 < contrast_ylim_high: + contrast_axes.axhline(y=0, **reflines_kwargs) + + if is_paired == "baseline" and show_pairs: + if two_col_sankey: + rightend_ticks_raw = np.array([len(i) - 2 for i in idx]) + np.array( + ticks_to_start_twocol_sankey + ) + elif proportional and is_paired is not None: + rightend_ticks_raw = np.array([len(i) - 1 for i in idx]) + np.array( + ticks_to_skip + ) + else: + rightend_ticks_raw = np.array( + [len(i) - 1 for i in temp_idx] + ) + np.array(ticks_to_skip) + for ax in [rawdata_axes]: + sns.despine(ax=ax, bottom=True) + + ylim = ax.get_ylim() + xlim = ax.get_xlim() + redraw_axes_kwargs["y"] = ylim[0] + + if two_col_sankey: + for k, start_tick in enumerate(ticks_to_start_twocol_sankey): + end_tick = rightend_ticks_raw[k] + ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) + else: + for k, start_tick in enumerate(ticks_to_skip): + end_tick = rightend_ticks_raw[k] + ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) + ax.set_ylim(ylim) + del redraw_axes_kwargs["y"] + + if not proportional: + temp_length = [(len(i) - 1) for i in idx] + else: + temp_length = [(len(i) - 1) * 2 - 1 for i in idx] + if two_col_sankey: + rightend_ticks_contrast = np.array( + [len(i) - 2 for i in idx] + ) + np.array(ticks_to_start_twocol_sankey) + elif proportional and is_paired is not None: + rightend_ticks_contrast = np.array( + [len(i) - 1 for i in idx] + ) + np.array(ticks_to_skip) + else: + rightend_ticks_contrast = np.array(temp_length) + np.array( + ticks_to_skip_contrast + ) + for ax in [contrast_axes]: + sns.despine(ax=ax, bottom=True) + + ylim = ax.get_ylim() + xlim = ax.get_xlim() + redraw_axes_kwargs["y"] = ylim[0] + + if two_col_sankey: + for k, start_tick in enumerate(ticks_to_start_twocol_sankey): + end_tick = rightend_ticks_contrast[k] + ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) + else: + for k, start_tick in enumerate(ticks_to_skip_contrast): + end_tick = rightend_ticks_contrast[k] + ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) + + ax.set_ylim(ylim) + del redraw_axes_kwargs["y"] + else: + # Compute the end of each x-axes line. + if two_col_sankey: + rightend_ticks = np.array([len(i) - 2 for i in idx]) + np.array( + ticks_to_start_twocol_sankey + ) + else: + rightend_ticks = np.array([len(i) - 1 for i in idx]) + np.array( + ticks_to_skip + ) + + for ax in [rawdata_axes, contrast_axes]: + sns.despine(ax=ax, bottom=True) + + ylim = ax.get_ylim() + xlim = ax.get_xlim() + redraw_axes_kwargs["y"] = ylim[0] + + if two_col_sankey: + for k, start_tick in enumerate(ticks_to_start_twocol_sankey): + end_tick = rightend_ticks[k] + ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) + else: + for k, start_tick in enumerate(ticks_to_skip): + end_tick = rightend_ticks[k] + ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) + + ax.set_ylim(ylim) + del redraw_axes_kwargs["y"] + +def General_Plot_Aesthetic_Adjustments(show_delta2, show_mini_meta, contrast_axes, redraw_axes_kwargs, plot_kwargs, + yvar, effect_size_type, proportional, effectsize_df, is_paired, float_contrast, + rawdata_axes, og_ylim_raw, effect_size): + + if show_delta2 or show_mini_meta: + ylim = contrast_axes.get_ylim() + redraw_axes_kwargs["y"] = ylim[0] + x_ticks = contrast_axes.get_xticks() + contrast_axes.hlines(xmin=x_ticks[-2], xmax=x_ticks[-1], **redraw_axes_kwargs) + del redraw_axes_kwargs["y"] + + # Set raw axes y-label. + swarm_label = plot_kwargs["swarm_label"] + if swarm_label is None and yvar is None: + swarm_label = "value" + elif swarm_label is None and yvar is not None: + swarm_label = yvar + + bar_label = plot_kwargs["bar_label"] + if bar_label is None and effect_size_type != "cohens_h": + bar_label = "proportion of success" + elif bar_label is None and effect_size_type == "cohens_h": + bar_label = "value" + + # Place contrast axes y-label. + contrast_label_dict = { + "mean_diff": "mean difference", + "median_diff": "median difference", + "cohens_d": "Cohen's d", + "hedges_g": "Hedges' g", + "cliffs_delta": "Cliff's delta", + "cohens_h": "Cohen's h", + "delta_g": "mean difference", + } + + if proportional and effect_size_type != "cohens_h": + default_contrast_label = "proportion difference" + elif effect_size_type == "delta_g": + default_contrast_label = "Hedges' g" + else: + default_contrast_label = contrast_label_dict[effectsize_df.effect_size] + + if plot_kwargs["contrast_label"] is None: + if is_paired: + contrast_label = "paired\n{}".format(default_contrast_label) + else: + contrast_label = default_contrast_label + contrast_label = contrast_label.capitalize() + else: + contrast_label = plot_kwargs["contrast_label"] + + if plot_kwargs["fontsize_rawylabel"] is not None: + fontsize_rawylabel = plot_kwargs["fontsize_rawylabel"] + if plot_kwargs["fontsize_contrastylabel"] is not None: + fontsize_contrastylabel = plot_kwargs["fontsize_contrastylabel"] + if plot_kwargs["fontsize_delta2label"] is not None: + fontsize_delta2label = plot_kwargs["fontsize_delta2label"] + + contrast_axes.set_ylabel(contrast_label, fontsize=fontsize_contrastylabel) + if float_contrast: + contrast_axes.yaxis.set_label_position("right") + + # Set the rawdata axes labels appropriately + if not proportional: + rawdata_axes.set_ylabel(swarm_label, fontsize=fontsize_rawylabel) + else: + rawdata_axes.set_ylabel(bar_label, fontsize=fontsize_rawylabel) + rawdata_axes.set_xlabel("") + + # Because we turned the axes frame off, we also need to draw back + # the y-spine for both axes. + if not float_contrast: + rawdata_axes.set_xlim(contrast_axes.get_xlim()) + og_xlim_raw = rawdata_axes.get_xlim() + rawdata_axes.vlines( + og_xlim_raw[0], og_ylim_raw[0], og_ylim_raw[1], **redraw_axes_kwargs + ) + + og_xlim_contrast = contrast_axes.get_xlim() + + if float_contrast: + xpos = og_xlim_contrast[1] + else: + xpos = og_xlim_contrast[0] + + og_ylim_contrast = contrast_axes.get_ylim() + contrast_axes.vlines( + xpos, og_ylim_contrast[0], og_ylim_contrast[1], **redraw_axes_kwargs + ) + + if show_delta2: + if plot_kwargs["delta2_label"] is not None: + delta2_label = plot_kwargs["delta2_label"] + elif effect_size == "mean_diff": + delta2_label = "delta - delta" + else: + delta2_label = "deltas' g" + delta2_axes = contrast_axes.twinx() + delta2_axes.set_frame_on(False) + delta2_axes.set_ylabel(delta2_label, fontsize=fontsize_delta2label) + og_xlim_delta = contrast_axes.get_xlim() + og_ylim_delta = contrast_axes.get_ylim() + delta2_axes.set_ylim(og_ylim_delta) + delta2_axes.vlines( + og_xlim_delta[1], og_ylim_delta[0], og_ylim_delta[1], **redraw_axes_kwargs + ) diff --git a/dabest/plot_tools.py b/dabest/plot_tools.py index 65fea009..ed791429 100644 --- a/dabest/plot_tools.py +++ b/dabest/plot_tools.py @@ -1,3 +1,5 @@ +"""A set of convenience functions used for producing plots in `dabest`.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/plot_tools.ipynb. # %% ../nbs/API/plot_tools.ipynb 2 @@ -5,7 +7,9 @@ # %% auto 0 __all__ = ['halfviolin', 'get_swarm_spans', 'error_bar', 'check_data_matches_labels', 'normalize_dict', 'width_determine', - 'single_sankey', 'sankeydiag', 'swarmplot', 'SwarmPlot'] + 'single_sankey', 'sankeydiag', 'summary_bars_plotter', 'contrast_bars_plotter', 'swarm_bars_plotter', + 'delta_text_plotter', 'DeltaDotsPlotter', 'slopegraph_plotter', 'plot_minimeta_or_deltadelta_violins', + 'effect_size_curve_plotter', 'grid_key_WIP', 'barplotter', 'swarmplot', 'SwarmPlot'] # %% ../nbs/API/plot_tools.ipynb 4 import math @@ -17,6 +21,7 @@ import matplotlib.pyplot as plt import matplotlib.lines as mlines import matplotlib.axes as axes +import matplotlib.patches as mpatches from collections import defaultdict from typing import List, Tuple, Dict, Iterable, Union from pandas.api.types import CategoricalDtype @@ -173,7 +178,8 @@ def error_bar( kwargs["zorder"] = kwargs["zorder"] - for xpos, central_measure in enumerate(central_measures): + for xpos, val in enumerate(central_measures.index): + central_measure = central_measures[val] kwargs["color"] = custom_palette[xpos] if method == "sankey_error_bar": @@ -181,8 +187,14 @@ def error_bar( else: _xpos = xpos + offset[xpos] - low = lows[xpos] - high = highs[xpos] + # Fix for the non-string x-axis issue #108 + if central_measures.index.dtype.name == "category": + low = lows[xpos] + high = highs[xpos] + else: + low = lows[val] + high = highs[val] + if low == high == central_measure: low_to_mean = mlines.Line2D( [_xpos, _xpos], [low, central_measure], **kwargs @@ -353,7 +365,7 @@ def single_sankey( strip_on: bool = True, # if True, draw strip for each group comparison one_sankey: bool = False, # if True, only draw one sankey diagram right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels - align: bool = "center", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick + align: str = "center", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick ): """ Make a single Sankey diagram showing proportion flow from left to right @@ -434,6 +446,7 @@ def single_sankey( if align not in ("center", "edge"): err = "{} assigned for `align` is not valid.".format(align) raise ValueError(err) + if align == "center": try: leftpos = xpos - width / 2 @@ -620,8 +633,9 @@ def sankeydiag( data: pd.DataFrame, xvar: str, # x column to be plotted. yvar: str, # y column to be plotted. - left_idx: str, # the value in column xvar that is on the left side of each sankey diagram - right_idx: str, # the value in column xvar that is on the right side of each sankey diagram, if len(left_idx) == 1, it will be broadcasted to the same length as right_idx, otherwise it should have the same length as right_idx + temp_all_plot_groups: list, + idx: list, + temp_idx: list, left_labels: list = None, # labels for the left side of the diagram. The diagram will be sorted by these labels. right_labels: list = None, # labels for the right side of the diagram. The diagram will be sorted by these labels. palette: str | dict = None, @@ -667,6 +681,30 @@ def sankeydiag( if ax is None: ax = plt.gca() + left_idx = [] + right_idx = [] + # Design for Sankey Flow Diagram + sankey_idx = ( + [ + (control, test) + for i in idx + for control, test in zip(i[:], (i[1:] + (i[0],))) + ] + if flow + else temp_idx + ) + for i in sankey_idx: + left_idx.append(i[0]) + right_idx.append(i[1]) + + if len(temp_all_plot_groups) == 2: + one_sankey = True + left_idx.pop() + right_idx.pop() # Remove the last element from two lists + + # two_col_sankey = True if proportional == True and one_sankey == False and sankey == True and flow == False else False + + allLabels = pd.Series(np.sort(data[yvar].unique())[::-1]).unique() # Check if all the elements in left_idx and right_idx are in xvar column @@ -778,6 +816,671 @@ def sankeydiag( ax.set_xticks([0, 1]) ax.set_xticklabels(sankey_ticks) + return left_idx, right_idx + +def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object, + float_contrast: bool,summary_bars_kwargs: dict, ci_type: str, + ticks_to_plot: list, color_col: str, swarm_colors: list, + proportional: bool, is_paired: bool): + """ + Add summary bars to the contrast plot. + + Parameters + ---------- + summary_bars : list + List of indices of the contrast objects to plot summary bars for. + results : object (Dataframe) + Dataframe of contrast object comparisons. + ax_to_plot : object + Matplotlib axis object to plot on. + float_contrast : bool + Whether the DABEST plot uses Gardner-Altman or Cummings. + summary_bars_kwargs : dict + Keyword arguments for the summary bars. + ci_type : str + Type of confidence interval to plot. + ticks_to_plot : list + List of indices of the contrast objects. + color_col : str + Column name of the color column. + swarm_colors : list + List of colors used in the plot. + proportional : bool + Whether the data is proportional. + is_paired : bool + Whether the data is paired. + """ +# Begin checks + if not isinstance(summary_bars, list): + raise TypeError("summary_bars must be a list of indices (ints).") + if not all(isinstance(i, int) for i in summary_bars): + raise TypeError("summary_bars must be a list of indices (ints).") + if any(i >= len(results) for i in summary_bars): + raise ValueError("Index {} chosen is out of range for the contrast objects.".format([i for i in summary_bars if i >= len(results)])) + if float_contrast: + raise ValueError("summary_bars cannot be used with Gardner-Altman plots.") +# End checks + else: + summary_xmin, summary_xmax = ax_to_plot.get_xlim() + summary_bars_colors = [summary_bars_kwargs.get('color')]*(max(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors + summary_bars_kwargs.pop('color') + for summary_index in summary_bars: + if ci_type == "bca": + summary_ci_low = results.bca_low[summary_index] + summary_ci_high = results.bca_high[summary_index] + else: + summary_ci_low = results.pct_low[summary_index] + summary_ci_high = results.pct_high[summary_index] + + summary_color = summary_bars_colors[ticks_to_plot[summary_index]] + + ax_to_plot.add_patch(mpatches.Rectangle((summary_xmin,summary_ci_low),summary_xmax+1, + summary_ci_high-summary_ci_low, zorder=-2, color=summary_color, **summary_bars_kwargs)) + + +def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, + ticks_to_plot: list, contrast_bars_kwargs: dict, color_col: str, + plot_palette_raw: dict, show_mini_meta: bool, mini_meta_delta: object, + show_delta2: bool, delta_delta: object, proportional: bool, is_paired: bool): + """ + Add contrast bars to the contrast plot. + + Parameters + ---------- + results : object (Dataframe) + Dataframe of contrast object comparisons. + ax_to_plot : object + Matplotlib axis object to plot on. + swarm_plot_ax : object (ax) + Matplotlib axis object of the swarm plot. + ticks_to_plot : list + List of indices of the contrast objects. + contrast_bars_kwargs : dict + Keyword arguments for the contrast bars. + color_col : str + Column name of the color column. + plot_palette_raw : dict + Dictionary of colors used in the plot. + show_mini_meta : bool + Whether to show the mini meta-analysis. + mini_meta_delta : object + Mini meta-analysis object. + show_delta2 : bool + Whether to show the delta-delta. + delta_delta : object + delta-delta object. + proportional : bool + Whether the data is proportional. + is_paired : bool + Whether the data is paired. + """ + contrast_means = [] + for j, tick in enumerate(ticks_to_plot): + contrast_means.append(results.difference[j]) + + contrast_bars_colors = ( + [contrast_bars_kwargs.get('color')] * (max(ticks_to_plot) + 1) + if contrast_bars_kwargs.get('color') is not None + else ['black'] * (max(ticks_to_plot) + 1) + if color_col is not None or (proportional and is_paired) or is_paired + else list(plot_palette_raw.values()) + ) + contrast_bars_kwargs.pop('color') + for contrast_bars_x,contrast_bars_y in zip(ticks_to_plot, contrast_means): + ax_to_plot.add_patch(mpatches.Rectangle((contrast_bars_x-0.25,0),0.5, contrast_bars_y, zorder=-1, color=contrast_bars_colors[contrast_bars_x], **contrast_bars_kwargs)) + + if show_mini_meta: + ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, mini_meta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs)) + + if show_delta2: + ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, delta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs)) + +def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object, + swarm_bars_kwargs: dict, color_col: str, plot_palette_raw: dict, is_paired: bool): + """ + Add bars to the raw data plot. + + Parameters + ---------- + plot_data : object (Dataframe) + Dataframe of the plot data. + xvar : str + Column name of the x variable. + yvar : str + Column name of the y variable. + ax : object + Matplotlib axis object to plot on. + swarm_bars_kwargs : dict + Keyword arguments for the swarm bars. + color_col : str + Column name of the color column. + plot_palette_raw : dict + Dictionary of colors used in the plot. + is_paired : bool + Whether the data is paired. + """ + + # if is_paired: + # swarm_bar_xlocs_adjustleft = {'right': -0.2, 'left': -0.2, 'center': -0.2} + # swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1} + # else: + # swarm_bar_xlocs_adjustleft = {'right': 0, 'left': -0.4, 'center': -0.2} + # swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1} + + if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype): + swarm_bars_order = pd.unique(plot_data[xvar]).categories + else: + swarm_bars_order = pd.unique(plot_data[xvar]) + + swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order) + # swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(max(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors + swarm_bars_colors = ( + [swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1) + if swarm_bars_kwargs.get('color') is not None + else ['black']*(len(swarm_bars_order)+1) + if color_col is not None or is_paired + else list(plot_palette_raw.values()) + ) + swarm_bars_kwargs.pop('color') + for swarm_bars_x,swarm_bars_y,c in zip(np.arange(0,len(swarm_bars_order)+1,1), swarm_means, swarm_bars_colors): + ax.add_patch(mpatches.Rectangle((swarm_bars_x-0.25,0), + 0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs)) + +def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, ticks_to_plot: list, delta_text_kwargs: dict, color_col: str, + swarm_colors: list, is_paired: bool, proportional: bool, float_contrast: bool, + show_mini_meta: bool, mini_meta_delta: object, show_delta2: bool, delta_delta: object): + """ + Add text to the contrast plot. + + Parameters + ---------- + results : object (Dataframe) + Dataframe of contrast object comparisons. + ax_to_plot : object + Matplotlib axis object to plot on. + swarm_plot_ax : object + Matplotlib axis object of the swarm plot. + ticks_to_plot : list + List of indices of the contrast objects. + delta_text_kwargs : dict + Keyword arguments for the delta text. + color_col : str + Column name of the color column. + swarm_colors : list + List of colors used in the plot. + is_paired : bool + Whether the data is paired. + proportional : bool + Whether the data is proportional. + float_contrast : bool + Whether the DABEST plot uses Gardner-Altman or Cummings + show_mini_meta : bool + Whether to show the mini meta-analysis. + mini_meta_delta : object + Mini meta-analysis object. + show_delta2 : bool + Whether to show the delta-delta. + delta_delta : object + delta-delta object. + """ + # Begin checks + delta_text_x_location = delta_text_kwargs.get('x_location') + if delta_text_x_location != 'right' and delta_text_x_location != 'left': + raise ValueError("delta_text_kwargs['x_location'] must be either 'right' or 'left'.") + if float_contrast: + delta_text_x_location = 'left' + delta_text_kwargs["va"] = 'bottom' if results.difference[0] >= 0 else 'top' + delta_text_kwargs.pop('x_location') + + delta_text_colors = [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1) if delta_text_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors + if show_mini_meta or show_delta2: delta_text_colors.append('black') + delta_text_kwargs.pop('color') + + total_ticks = len(ticks_to_plot) + 1 if show_mini_meta or show_delta2 else len(ticks_to_plot) + + # Collect the Y-values for the delta text + Delta_Values = [] + for j, tick in enumerate(ticks_to_plot): + Delta_Values.append(results.difference[j]) + if show_delta2: Delta_Values.append(delta_delta.difference) + if show_mini_meta: Delta_Values.append(mini_meta_delta.difference) + + # Collect the X-coordinates for the delta text + delta_text_x_coordinates = delta_text_kwargs.get('x_coordinates') + + if delta_text_x_coordinates is not None: + if not isinstance(delta_text_x_coordinates, list): + raise TypeError("delta_text_kwargs['x_coordinates'] must be a list of x-coordinates.") + if len(delta_text_x_coordinates) != len(total_ticks): + raise ValueError("delta_text_kwargs['x_coordinates'] must have the same length as the number of ticks to plot.") + else: + delta_text_x_coordinates = ticks_to_plot + X_Adjust = 0.48 if delta_text_x_location == 'right' else -0.38 + delta_text_x_coordinates = [x+X_Adjust for x in delta_text_x_coordinates] + if show_mini_meta: delta_text_x_coordinates.append(max(swarm_plot_ax.get_xticks())+2+X_Adjust) + if show_delta2: delta_text_x_coordinates.append(max(swarm_plot_ax.get_xticks())+2-0.35) + if show_mini_meta or show_delta2: ticks_to_plot.append(max(ticks_to_plot)+1) + delta_text_kwargs.pop('x_coordinates') + + # Collect the Y-coordinates for the delta text + delta_text_y_coordinates = delta_text_kwargs.get('y_coordinates') + + if delta_text_y_coordinates is not None: + if not isinstance(delta_text_y_coordinates, list): + raise TypeError("delta_text_kwargs['y_coordinates'] must be a list of y-coordinates.") + if len(delta_text_y_coordinates) != len(total_ticks): + raise ValueError("delta_text_kwargs['y_coordinates'] must have the same length as the number of ticks to plot.") + else: + delta_text_y_coordinates = Delta_Values + + delta_text_kwargs.pop('y_coordinates') + + # Plot the delta text + for x,y,t,tick in zip(delta_text_x_coordinates, delta_text_y_coordinates,Delta_Values,ticks_to_plot): + Delta_Text = np.format_float_positional(t, precision=2, sign=True, trim="k", min_digits=2) + ax_to_plot.text(x, y, Delta_Text, color=delta_text_colors[tick], zorder=5, **delta_text_kwargs) + + +def DeltaDotsPlotter(plot_data, contrast_axes, delta_id_col, idx, xvar, yvar, is_paired, color_col, float_contrast, plot_palette_raw, delta_dot_kwargs): + """ + Parameters + ---------- + plot_data : object (Dataframe) + Dataframe of the plot data. + contrast_axes : object + Matplotlib axis object to plot on. + delta_id_col : str + Column name of the delta id column. + idx : list + List of indices of the contrast objects. + xvar : str + Column name of the x variable. + yvar : str + Column name of the y variable. + is_paired : bool + Whether the data is paired. + color_col : str + Column name of the color column. + float_contrast : bool + Whether the DABEST plot uses Gardner-Altman or Cummings + plot_palette_raw : dict + Dictionary of colors used in the plot. + delta_dot_kwargs : dict + Keyword arguments for the delta dots. + """ + + # Checks and initializations + from .plot_tools import swarmplot + + if color_col is not None: + plot_palette_deltapts = plot_palette_raw + delta_plot_data = plot_data[[xvar, yvar, delta_id_col, color_col]] + else: + plot_palette_deltapts = "k" + delta_plot_data = plot_data[[xvar, yvar, delta_id_col]] + + # TODO: to make jitter value more accurate and not just a hardcoded eyeball value + jitter = 0.6 if float_contrast else 1 + + # Create dataframe of delta values + final_deltas = pd.DataFrame() + for i in idx: + for j in i: + if i.index(j) != 0: + temp_df_exp = delta_plot_data[ + delta_plot_data[xvar].str.contains(j) + ].reset_index(drop=True) + if is_paired == "baseline": + temp_df_cont = delta_plot_data[ + delta_plot_data[xvar].str.contains(i[0]) + ].reset_index(drop=True) + elif is_paired == "sequential": + temp_df_cont = delta_plot_data[ + delta_plot_data[xvar].str.contains( + i[i.index(j) - 1] + ) + ].reset_index(drop=True) + delta_df = temp_df_exp.copy() + delta_df[yvar] = temp_df_exp[yvar] - temp_df_cont[yvar] + final_deltas = pd.concat([final_deltas, delta_df]) + + # Plot the delta dots + swarmplot( + data=final_deltas, + x=xvar, + y=yvar, + ax=contrast_axes, + order=None, + hue=color_col, + palette=plot_palette_deltapts, + jitter=jitter, + is_drop_gutter=True, + gutter_limit=1, + **delta_dot_kwargs) + contrast_axes.legend().set_visible(False) + + +def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palette_raw, slopegraph_kwargs, rawdata_axes, ytick_color, temp_idx): + + # Pivot the long (melted) data. + if color_col is None: + pivot_values = [yvar] + else: + pivot_values = [yvar, color_col] + pivoted_plot_data = pd.pivot( + data=plot_data, + index=dabest_obj.id_col, + columns=xvar, + values=pivot_values, + ) + + x_start = 0 + for ii, current_tuple in enumerate(temp_idx): + current_pair = pivoted_plot_data.loc[ + :, pd.MultiIndex.from_product([pivot_values, current_tuple]) + ].dropna() + grp_count = len(current_tuple) + # Iterate through the data for the current tuple. + for ID, observation in current_pair.iterrows(): + x_points = [t for t in range(x_start, x_start + grp_count)] + y_points = observation[yvar].tolist() + + if color_col is None: + slopegraph_kwargs["color"] = ytick_color + else: + color_key = observation[color_col][0] + if isinstance(color_key, (str, np.int64, np.float64)): + slopegraph_kwargs["color"] = plot_palette_raw[color_key] + slopegraph_kwargs["label"] = color_key + + rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs) + + x_start = x_start + grp_count + +def plot_minimeta_or_deltadelta_violins(show_mini_meta, effectsize_df, ci_type, rawdata_axes, + contrast_axes, violinplot_kwargs, halfviolin_alpha, ytick_color, + es_marker_size, group_summary_kwargs, contrast_xtick_labels, effect_size + ): + if show_mini_meta: + mini_meta_delta = effectsize_df.mini_meta_delta + data = mini_meta_delta.bootstraps_weighted_delta + difference = mini_meta_delta.difference + if ci_type == "bca": + ci_low = mini_meta_delta.bca_low + ci_high = mini_meta_delta.bca_high + else: + ci_low = mini_meta_delta.pct_low + ci_high = mini_meta_delta.pct_high + else: + delta_delta = effectsize_df.delta_delta + data = delta_delta.bootstraps_delta_delta + difference = delta_delta.difference + if ci_type == "bca": + ci_low = delta_delta.bca_low + ci_high = delta_delta.bca_high + else: + ci_low = delta_delta.pct_low + ci_high = delta_delta.pct_high + # Create the violinplot. + # New in v0.2.6: drop negative infinities before plotting. + position = max(rawdata_axes.get_xticks()) + 2 + v = contrast_axes.violinplot( + data[~np.isinf(data)], positions=[position], **violinplot_kwargs + ) + + fc = "grey" + + halfviolin(v, fill_color=fc, alpha=halfviolin_alpha) + + # Plot the effect size. + contrast_axes.plot( + [position], + difference, + marker="o", + color=ytick_color, + markersize=es_marker_size, + ) + # Plot the confidence interval. + contrast_axes.plot( + [position, position], + [ci_low, ci_high], + linestyle="-", + color=ytick_color, + linewidth=group_summary_kwargs["lw"], + ) + if show_mini_meta: + contrast_xtick_labels.extend(["", "Weighted delta"]) + elif effect_size == "delta_g": + contrast_xtick_labels.extend(["", "deltas' g"]) + else: + contrast_xtick_labels.extend(["", "delta-delta"]) + + return contrast_xtick_labels + + +def effect_size_curve_plotter(ticks_to_plot, results, ci_type, contrast_axes, violinplot_kwargs, halfviolin_alpha, + ytick_color, es_marker_size, group_summary_kwargs, bootstraps_color_by_group, plot_palette_contrast): + contrast_xtick_labels = [] + for j, tick in enumerate(ticks_to_plot): + current_group = results.test[j] + current_control = results.control[j] + current_bootstrap = results.bootstraps[j] + current_effsize = results.difference[j] + if ci_type == "bca": + current_ci_low = results.bca_low[j] + current_ci_high = results.bca_high[j] + else: + current_ci_low = results.pct_low[j] + current_ci_high = results.pct_high[j] + + # Create the violinplot. + # New in v0.2.6: drop negative infinities before plotting. + v = contrast_axes.violinplot( + current_bootstrap[~np.isinf(current_bootstrap)], + positions=[tick], + **violinplot_kwargs + ) + # Turn the violinplot into half, and color it the same as the swarmplot. + # Do this only if the color column is not specified. + # Ideally, the alpha (transparency) fo the violin plot should be + # less than one so the effect size and CIs are visible. + if bootstraps_color_by_group: + fc = plot_palette_contrast[current_group] + else: + fc = "grey" + + halfviolin(v, fill_color=fc, alpha=halfviolin_alpha) + + # Plot the effect size. + contrast_axes.plot( + [tick], + current_effsize, + marker="o", + color=ytick_color, + markersize=es_marker_size, + ) + + # Plot the confidence interval. + contrast_axes.plot( + [tick, tick], + [current_ci_low, current_ci_high], + linestyle="-", + color=ytick_color, + linewidth=group_summary_kwargs["lw"], + ) + + contrast_xtick_labels.append( + "{}\nminus\n{}".format(current_group, current_control) + ) + return current_group, current_control, current_effsize, contrast_xtick_labels + + +def grid_key_WIP(is_paired, idx, all_plot_groups, gridkey_rows, rawdata_axes, contrast_axes, + plot_data, xvar, yvar, results, show_delta2, show_mini_meta, float_contrast, plot_kwargs,): + + gridkey_show_Ns=plot_kwargs["gridkey_show_Ns"] + gridkey_show_es=plot_kwargs["gridkey_show_es"] + gridkey_merge_pairs=plot_kwargs["gridkey_merge_pairs"] + + # Raise error if there are more than 2 items in any idx and gridkey_merge_pairs is True and is_paired is not None + if gridkey_merge_pairs and is_paired is not None: + for i in idx: + if len(i) > 2: + warnings.warn( + "gridkey_merge_pairs=True only works if all idx in tuples have only two items. gridkey_merge_pairs has automatically been set to False" + ) + gridkey_merge_pairs = False + break + elif gridkey_merge_pairs and is_paired is None: + warnings.warn( + "gridkey_merge_pairs=True is only applicable for paired data." + ) + gridkey_merge_pairs = False + + # Checks for gridkey_merge_pairs and is_paired; if both are true, "merges" the gridkey per pair + if gridkey_merge_pairs and is_paired is not None: + groups_for_gridkey = [] + for i in idx: + groups_for_gridkey.append(i[1]) + else: + groups_for_gridkey = all_plot_groups + + # raise errors if gridkey_rows is not a list, or if the list is empty + if isinstance(gridkey_rows, list) is False: + raise TypeError("gridkey_rows must be a list.") + elif len(gridkey_rows) == 0: + warnings.warn("gridkey_rows is an empty list.") + + # raise Warning if an item in gridkey_rows is not contained in any idx + for i in gridkey_rows: + in_idx = 0 + for j in groups_for_gridkey: + if i in j: + in_idx += 1 + if in_idx == 0: + if is_paired is not None: + warnings.warn( + i + + " is not in any idx. Please check. Alternatively, merging gridkey pairs may not be suitable for your data; try passing gridkey_merge_pairs=False." + ) + else: + warnings.warn(i + " is not in any idx. Please check.") + + # Populate table: checks if idx for each column contains rowlabel name + # IF so, marks that element as present w black dot, or space if not present + table_cellcols = [] + for i in gridkey_rows: + thisrow = [] + for q in groups_for_gridkey: + if str(i) in q: + thisrow.append("\u25CF") + else: + thisrow.append("") + table_cellcols.append(thisrow) + + # Adds a row for Ns with the Ns values + if gridkey_show_Ns: + gridkey_rows.append("Ns") + list_of_Ns = [] + for i in groups_for_gridkey: + list_of_Ns.append(str(plot_data.groupby(xvar).count()[yvar].loc[i])) + table_cellcols.append(list_of_Ns) + + # Adds a row for effectsizes with effectsize values + if gridkey_show_es: + gridkey_rows.append("\u0394") + effsize_list = [] + results_list = results.test.to_list() + + # get the effect size, append + or -, 2 dec places + for i in enumerate(groups_for_gridkey): + if i[1] in results_list: + curr_esval = results.loc[results["test"] == i[1]][ + "difference" + ].iloc[0] + curr_esval_str = np.format_float_positional( + curr_esval, + precision=2, + sign=True, + trim="k", + min_digits=2, + ) + effsize_list.append(curr_esval_str) + else: + effsize_list.append("-") + + table_cellcols.append(effsize_list) + + # If Gardner-Altman plot, plot on raw data and not contrast axes + if float_contrast: + axes_ploton = rawdata_axes + else: + axes_ploton = contrast_axes + + # Account for extended x axis in case of show_delta2 or show_mini_meta + x_groups_for_width = len(groups_for_gridkey) + if show_delta2 or show_mini_meta: + x_groups_for_width += 2 + gridkey_width = len(groups_for_gridkey) / x_groups_for_width + + gridkey = axes_ploton.table( + cellText=table_cellcols, + rowLabels=gridkey_rows, + cellLoc="center", + bbox=[ + 0, + -len(gridkey_rows) * 0.1 - 0.05, + gridkey_width, + len(gridkey_rows) * 0.1, + ], + **{"alpha": 0.5} + ) + + # modifies row label cells + for cell in gridkey._cells: + if cell[1] == -1: + gridkey._cells[cell].visible_edges = "open" + gridkey._cells[cell].set_text_props(**{"ha": "right"}) + + # turns off both x axes + rawdata_axes.get_xaxis().set_visible(False) + contrast_axes.get_xaxis().set_visible(False) + +def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color, plot_palette_bar, plot_kwargs, barplot_kwargs): + # Plot the raw data as a barplot. + bar1_df = pd.DataFrame( + {xvar: all_plot_groups, "proportion": np.ones(len(all_plot_groups))} + ) + bar1 = sns.barplot( + data=bar1_df, + x=xvar, + y="proportion", + ax=rawdata_axes, + order=all_plot_groups, + linewidth=2, + facecolor=(1, 1, 1, 0), + edgecolor=bar_color, + zorder=1, + ) + bar2 = sns.barplot( + data=plot_data, + x=xvar, + y=yvar, + ax=rawdata_axes, + order=all_plot_groups, + palette=plot_palette_bar, + zorder=1, + **barplot_kwargs + ) + # adjust the width of bars + bar_width = plot_kwargs["bar_width"] + for bar in bar1.patches: + x = bar.get_x() + width = bar.get_width() + centre = x + width / 2.0 + bar.set_x(centre - bar_width / 2.0) + bar.set_width(bar_width) + # %% ../nbs/API/plot_tools.ipynb 6 def swarmplot( data: pd.DataFrame, @@ -791,6 +1494,7 @@ def swarmplot( size: float = 5, side: str = "center", jitter: float = 1, + filled: Union[bool, List, Tuple] = True, is_drop_gutter: bool = True, gutter_limit: float = 0.5, **kwargs, @@ -823,6 +1527,11 @@ def swarmplot( The side on which points are swarmed ("center", "left", or "right"). Default is "center". jitter : int | float Determines the distance between points. Default is 1. + filled : bool | List | Tuple + Determines whether the dots in the swarmplot are filled or not. If set to False, + dots are not filled. If provided as a List or Tuple, it should contain boolean values, + each corresponding to a swarm group in order, indicating whether the dot should be + filled or not. is_drop_gutter : bool If True, drop points that hit the gutters; otherwise, readjust them. gutter_limit : int | float @@ -836,7 +1545,7 @@ def swarmplot( Matplotlib AxesSubplot object for which the swarm plot has been drawn on. """ s = SwarmPlot(data, x, y, ax, order, hue, palette, zorder, size, side, jitter) - ax = s.plot(is_drop_gutter, gutter_limit, ax, **kwargs) + ax = s.plot(is_drop_gutter, gutter_limit, ax, filled, **kwargs) return ax @@ -996,7 +1705,9 @@ def _check_errors( if not isinstance(self.__jitter, (int, float)): raise ValueError("`jitter` must be a scalar or float.") if not isinstance(self.__palette, (str, Iterable)): - raise ValueError("`palette` must be either a string indicating a color name or an Iterable.") + raise ValueError( + "`palette` must be either a string indicating a color name or an Iterable." + ) if self.__hue is not None and not isinstance(self.__hue, str): raise ValueError("`hue` must be either a string or None.") if self.__order is not None and not isinstance(self.__order, Iterable): @@ -1026,7 +1737,6 @@ def _check_errors( err = "`palette` cannot be an empty string. It must be either a string indicating a color name or an Iterable." raise ValueError(err) if isinstance(self.__palette, dict): - # TODO: to add detection of when dict length is less than size of unique_items for group_i, color_i in self.__palette.items(): if group_i not in pd.unique(data[color_col]): err = ( @@ -1036,8 +1746,10 @@ def _check_errors( ) raise IndexError(err) if isinstance(color_i, str) and color_i.strip() == "": - err = "The color mapping for {0} in `palette` is an empty string. It must contain a color name.".format(group_i) - raise ValueError(err) + err = "The color mapping for {0} in `palette` is an empty string. It must contain a color name.".format( + group_i + ) + raise ValueError(err) if side.lower() not in ["center", "right", "left"]: raise ValueError( @@ -1239,7 +1951,12 @@ def _adjust_gutter_points( return points_data def plot( - self, is_drop_gutter: bool, gutter_limit: float, ax: axes.Subplot, **kwargs + self, + is_drop_gutter: bool, + gutter_limit: float, + ax: axes.Subplot, + filled: Union[bool, List, Tuple], + **kwargs, ) -> axes.Subplot: """ Generate a swarm plot. @@ -1252,6 +1969,11 @@ def plot( The limit for points hitting the gutters. ax : axes.Subplot The matplotlib figure object to which the swarm plot will be added. + filled : bool | List | Tuple + Determines whether the dots in the swarmplot are filled or not. If set to False, + dots are not filled. If provided as a List or Tuple, it should contain boolean values, + each corresponding to a swarm group in order, indicating whether the dot should be + filled or not. **kwargs: Additional keyword arguments to be passed to the scatter plot. @@ -1265,12 +1987,28 @@ def plot( raise ValueError("`is_drop_gutter` must be a boolean.") if not isinstance(gutter_limit, (int, float)): raise ValueError("`gutter_limit` must be a scalar or float.") + if not isinstance(filled, (bool, list, tuple)): + raise ValueError("`filled` must be a boolean, list or tuple.") + + # More thorough input validation checks + if isinstance(filled, (list, tuple)): + if len(filled) != len(self.__order): + err = ( + "There are {0} unique values in `x` column in `data` " + "but `filled` has a length of {1}. If `filled` is a list " + "or a tuple, it must have the same length as the number of " + "unique values/groups in the `x` column of data." + ).format(len(self.__order), len(filled)) + raise ValueError(err) + if not all(isinstance(_, bool) for _ in filled): + raise ValueError("All values in `filled` must be a boolean.") # Assumptions are that self.__data_copy is already sorted according to self.__order x_position = ( 0 # x-coordinate of center of each individual swarm of the swarm plot ) x_tick_tabels = [] + for group_i, values_i in self.__data_copy.groupby(self.__x): x_new = [] values_i_y = values_i[self.__y] @@ -1308,6 +2046,10 @@ def plot( cmap = [] for cmap_group_i in cmap_values: cmap.append(self.__palette[cmap_group_i]) + + # WIP: legend for swarm plot + swarm_legend_kwargs = {'colors':cmap, 'labels':cmap_values, 'index':index} + cmap = ListedColormap(cmap) ax.scatter( values_i["x_new"], @@ -1319,19 +2061,43 @@ def plot( edgecolor="face", **kwargs, ) + else: # color swarms based on `x` column + if not isinstance(filled, bool): + facecolor = ( + "none" + if not filled[x_position - 1] + else self.__palette[group_i] + ) + else: + facecolor = "none" if not filled else self.__palette[group_i] + ax.scatter( values_i["x_new"], values_i[self.__y], s=self.__size, - c=self.__palette[group_i], zorder=self.__zorder, - edgecolor="face", + facecolor=facecolor, + edgecolor=self.__palette[group_i], + label=group_i, **kwargs, ) + # Handling of legends + # This is currently a workaround because c and cmap is unable to map the labels when calling scatter() + # labels has to be used to designate legend labels and handles in scatter() due to the potential calling of ax.get_legend_handles_labels() + if self.__hue is not None: + for cmap_group_i in self.__palette: + ax.scatter( + [], + [], + c=self.__palette[cmap_group_i], + label=cmap_group_i, + ) + handles, labels = ax.get_legend_handles_labels() + ax.get_xaxis().set_ticks(np.arange(x_position)) ax.get_xaxis().set_ticklabels(x_tick_tabels) - return ax + return ax, swarm_legend_kwargs if self.__hue is not None else None diff --git a/dabest/plotter.py b/dabest/plotter.py index fcd65ee5..f2c30a1a 100644 --- a/dabest/plotter.py +++ b/dabest/plotter.py @@ -1,3 +1,5 @@ +"""Creating estimation plots.""" + # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/plotter.ipynb. # %% auto 0 @@ -8,6 +10,8 @@ import seaborn as sns import matplotlib import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.lines import Line2D import pandas as pd import warnings import logging @@ -52,19 +56,41 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs): title=None, fontsize_title=16, fontsize_rawxlabel=12, fontsize_rawylabel=12, fontsize_contrastxlabel=12, fontsize_contrastylabel=12, - fontsize_delta2label=12 + fontsize_delta2label=12, + swarm_bars=True, swarm_bars_kwargs=None, + contrast_bars=True, contrast_bars_kwargs=None, + delta_text=True, delta_text_kwargs=None, + delta_dot=True, delta_dot_kwargs=None, """ - from .misc_tools import merge_two_dicts + from .misc_tools import ( + get_params, + get_kwargs, + get_color_palette, + initialize_fig, + get_plot_groups, + add_counts_to_ticks, + extract_contrast_plotting_ticks, + set_xaxis_ticks_and_lims, + show_legend, + Gardner_Altman_Plot_Aesthetic_Adjustments, + Cumming_Plot_Aesthetic_Adjustments, + General_Plot_Aesthetic_Adjustments, + ) from .plot_tools import ( - halfviolin, get_swarm_spans, error_bar, sankeydiag, swarmplot, - ) - from ._stats_tools.effsize import ( - _compute_standardizers, - _compute_hedges_correction_factor, + swarm_bars_plotter, + contrast_bars_plotter, + summary_bars_plotter, + delta_text_plotter, + DeltaDotsPlotter, + slopegraph_plotter, + plot_minimeta_or_deltadelta_violins, + effect_size_curve_plotter, + grid_key_WIP, + barplotter, ) warnings.filterwarnings( @@ -82,499 +108,95 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs): original_rcParams[parameter] = plt.rcParams[parameter] plt.rcParams["axes.grid"] = False - ytick_color = plt.rcParams["ytick.color"] - face_color = plot_kwargs["face_color"] - - if plot_kwargs["face_color"] is None: - face_color = "white" - - dabest_obj = effectsize_df.dabest_obj - plot_data = effectsize_df._plot_data - xvar = effectsize_df.xvar - yvar = effectsize_df.yvar - is_paired = effectsize_df.is_paired - delta2 = effectsize_df.delta2 - mini_meta = effectsize_df.mini_meta - effect_size = effectsize_df.effect_size - proportional = effectsize_df.proportional - - all_plot_groups = dabest_obj._all_plot_groups - idx = dabest_obj.idx - - if effect_size not in ["mean_diff", "delta_g"] or not delta2: - show_delta2 = False - else: - show_delta2 = plot_kwargs["show_delta2"] - - if effect_size != "mean_diff" or not mini_meta: - show_mini_meta = False - else: - show_mini_meta = plot_kwargs["show_mini_meta"] - if show_delta2 and show_mini_meta: - err0 = "`show_delta2` and `show_mini_meta` cannot be True at the same time." - raise ValueError(err0) + # Extract parameters and set kwargs + (dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, + proportional, all_plot_groups, idx, show_delta2, show_mini_meta, + float_contrast, show_pairs, effect_size_type, group_summaries, err_color) = get_params( + effectsize_df=effectsize_df, + plot_kwargs=plot_kwargs + ) + + (swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs, + slopegraph_kwargs, reflines_kwargs, legend_kwargs, group_summary_kwargs, redraw_axes_kwargs, + delta_dot_kwargs, delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs) = get_kwargs( + plot_kwargs=plot_kwargs, + ytick_color=ytick_color + ) - # Disable Gardner-Altman plotting if any of the idxs comprise of more than - # two groups or if it is a delta-delta plot. - float_contrast = plot_kwargs["float_contrast"] - effect_size_type = effectsize_df.effect_size - if len(idx) > 1 or len(idx[0]) > 2: - float_contrast = False - - if effect_size_type in ["cliffs_delta"]: - float_contrast = False - - if show_delta2 or show_mini_meta: - float_contrast = False - - if not is_paired: - show_pairs = False - else: - show_pairs = plot_kwargs["show_pairs"] - - # Set default kwargs first, then merge with user-dictated ones. - # Swarmplot kwargs - default_swarmplot_kwargs = {"size": plot_kwargs["raw_marker_size"]} - if plot_kwargs["swarmplot_kwargs"] is None: - swarmplot_kwargs = default_swarmplot_kwargs - else: - swarmplot_kwargs = merge_two_dicts( - default_swarmplot_kwargs, plot_kwargs["swarmplot_kwargs"] - ) - asymmetric_side = ( - "left" # TODO: allow users to control side for swarms of swarmplot. - ) - - # Barplot kwargs - default_barplot_kwargs = {"estimator": np.mean, "errorbar": plot_kwargs["ci"]} - - if plot_kwargs["barplot_kwargs"] is None: - barplot_kwargs = default_barplot_kwargs - else: - barplot_kwargs = merge_two_dicts( - default_barplot_kwargs, plot_kwargs["barplot_kwargs"] - ) - - # Sankey Diagram kwargs - default_sankey_kwargs = { - "width": 0.4, - "align": "center", - "sankey": True, - "flow": True, - "alpha": 0.4, - "rightColor": False, - "bar_width": 0.2, - } - if plot_kwargs["sankey_kwargs"] is None: - sankey_kwargs = default_sankey_kwargs - else: - sankey_kwargs = merge_two_dicts( - default_sankey_kwargs, plot_kwargs["sankey_kwargs"] - ) # We also need to extract the `sankey` and `flow` from the kwargs for plotter.py # to use for varying different kinds of paired proportional plots # We also don't want to pop the parameter from the kwargs - sankey = sankey_kwargs["sankey"] - flow = sankey_kwargs["flow"] - - # Violinplot kwargs. - default_violinplot_kwargs = { - "widths": 0.5, - "vert": True, - "showextrema": False, - "showmedians": False, - } - if plot_kwargs["violinplot_kwargs"] is None: - violinplot_kwargs = default_violinplot_kwargs - else: - violinplot_kwargs = merge_two_dicts( - default_violinplot_kwargs, plot_kwargs["violinplot_kwargs"] - ) - - # Slopegraph kwargs. - default_slopegraph_kwargs = {"linewidth": 1, "alpha": 0.5} - if plot_kwargs["slopegraph_kwargs"] is None: - slopegraph_kwargs = default_slopegraph_kwargs - else: - slopegraph_kwargs = merge_two_dicts( - default_slopegraph_kwargs, plot_kwargs["slopegraph_kwargs"] - ) - - # Zero reference-line kwargs. - default_reflines_kwargs = { - "linestyle": "solid", - "linewidth": 0.75, - "zorder": 2, - "color": ytick_color, - } - if plot_kwargs["reflines_kwargs"] is None: - reflines_kwargs = default_reflines_kwargs - else: - reflines_kwargs = merge_two_dicts( - default_reflines_kwargs, plot_kwargs["reflines_kwargs"] - ) - - # Legend kwargs. - default_legend_kwargs = {"loc": "upper left", "frameon": False} - if plot_kwargs["legend_kwargs"] is None: - legend_kwargs = default_legend_kwargs - else: - legend_kwargs = merge_two_dicts( - default_legend_kwargs, plot_kwargs["legend_kwargs"] - ) - - ################################################### GRIDKEY WIP - extracting arguments - - gridkey_rows = plot_kwargs["gridkey_rows"] - gridkey_merge_pairs = plot_kwargs["gridkey_merge_pairs"] - gridkey_show_Ns = plot_kwargs["gridkey_show_Ns"] - gridkey_show_es = plot_kwargs["gridkey_show_es"] - - if gridkey_rows is None: - gridkey_show_Ns = False - gridkey_show_es = False - - ################################################### END GRIDKEY WIP - extracting arguments - - # Group summaries kwargs. - gs_default = {"mean_sd", "median_quartiles", None} - if plot_kwargs["group_summaries"] not in gs_default: - raise ValueError( - "group_summaries must be one of" " these: {}.".format(gs_default) - ) - - default_group_summary_kwargs = {"zorder": 3, "lw": 2, "alpha": 1} - if plot_kwargs["group_summary_kwargs"] is None: - group_summary_kwargs = default_group_summary_kwargs - else: - group_summary_kwargs = merge_two_dicts( - default_group_summary_kwargs, plot_kwargs["group_summary_kwargs"] - ) - - # Create color palette that will be shared across subplots. - color_col = plot_kwargs["color_col"] - if color_col is None: - color_groups = pd.unique(plot_data[xvar]) - bootstraps_color_by_group = True - else: - if color_col not in plot_data.columns: - raise KeyError("``{}`` is not a column in the data.".format(color_col)) - color_groups = pd.unique(plot_data[color_col]) - bootstraps_color_by_group = False - if show_pairs: - bootstraps_color_by_group = False - - # Handle the color palette. - names = color_groups - n_groups = len(color_groups) - custom_pal = plot_kwargs["custom_palette"] - swarm_desat = plot_kwargs["swarm_desat"] - bar_desat = plot_kwargs["bar_desat"] - contrast_desat = plot_kwargs["halfviolin_desat"] - - if custom_pal is None: - unsat_colors = sns.color_palette(n_colors=n_groups) - else: - if isinstance(custom_pal, dict): - groups_in_palette = { - k: v for k, v in custom_pal.items() if k in color_groups - } - - names = groups_in_palette.keys() - unsat_colors = groups_in_palette.values() - - elif isinstance(custom_pal, list): - unsat_colors = custom_pal[0:n_groups] - - elif isinstance(custom_pal, str): - # check it is in the list of matplotlib palettes. - if custom_pal in plt.colormaps(): - unsat_colors = sns.color_palette(custom_pal, n_groups) - else: - err1 = "The specified `custom_palette` {}".format(custom_pal) - err2 = " is not a matplotlib palette. Please check." - raise ValueError(err1 + err2) - - if custom_pal is None and color_col is None: - swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors] - plot_palette_raw = dict(zip(names.categories, swarm_colors)) - - bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors] - plot_palette_bar = dict(zip(names.categories, bar_color)) - - contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors] - plot_palette_contrast = dict(zip(names.categories, contrast_colors)) - - # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors - # default color palette will be set to "hls" - plot_palette_sankey = None - - else: - swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors] - plot_palette_raw = dict(zip(names, swarm_colors)) - - bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors] - plot_palette_bar = dict(zip(names, bar_color)) - - contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors] - plot_palette_contrast = dict(zip(names, contrast_colors)) - - plot_palette_sankey = custom_pal - - # Infer the figsize. - fig_size = plot_kwargs["fig_size"] - if fig_size is None: - all_groups_count = np.sum([len(i) for i in dabest_obj.idx]) - # Increase the width for delta-delta graph - if show_delta2 or show_mini_meta: - all_groups_count += 2 - if is_paired and show_pairs and proportional is False: - frac = 0.75 - else: - frac = 1 - if float_contrast: - height_inches = 4 - each_group_width_inches = 2.5 * frac - else: - height_inches = 6 - each_group_width_inches = 1.5 * frac - - width_inches = each_group_width_inches * all_groups_count - fig_size = (width_inches, height_inches) - - # Initialise the figure. - init_fig_kwargs = dict(figsize=fig_size, dpi=plot_kwargs["dpi"], tight_layout=True) - - width_ratios_ga = [2.5, 1] - - ###################### GRIDKEY HSPACE ALTERATION - - # Sets hspace for cummings plots if gridkey is shown. - if gridkey_rows is not None: - h_space_cummings = 0.1 - else: - h_space_cummings = 0.3 - - ###################### END GRIDKEY HSPACE ALTERATION - - if plot_kwargs["ax"] is not None: - # New in v0.2.6. - # Use inset axes to create the estimation plot inside a single axes. - # Author: Adam L Nekimken. (PR #73) - rawdata_axes = plot_kwargs["ax"] - ax_position = rawdata_axes.get_position() # [[x0, y0], [x1, y1]] - - fig = rawdata_axes.get_figure() - fig.patch.set_facecolor(face_color) - - if float_contrast: - axins = rawdata_axes.inset_axes( - [1, 0, width_ratios_ga[1] / width_ratios_ga[0], 1] - ) - rawdata_axes.set_position( # [l, b, w, h] - [ - ax_position.x0, - ax_position.y0, - (ax_position.x1 - ax_position.x0) - * (width_ratios_ga[0] / sum(width_ratios_ga)), - (ax_position.y1 - ax_position.y0), - ] - ) - - contrast_axes = axins - - else: - axins = rawdata_axes.inset_axes([0, -1 - h_space_cummings, 1, 1]) - plot_height = (ax_position.y1 - ax_position.y0) / (2 + h_space_cummings) - rawdata_axes.set_position( - [ - ax_position.x0, - ax_position.y0 + (1 + h_space_cummings) * plot_height, - (ax_position.x1 - ax_position.x0), - plot_height, - ] - ) - - contrast_axes = axins - rawdata_axes.contrast_axes = axins - - else: - # Here, we hardcode some figure parameters. - if float_contrast: - fig, axx = plt.subplots( - ncols=2, - gridspec_kw={"width_ratios": width_ratios_ga, "wspace": 0}, - **init_fig_kwargs - ) - fig.patch.set_facecolor(face_color) - - else: - fig, axx = plt.subplots( - nrows=2, gridspec_kw={"hspace": h_space_cummings}, **init_fig_kwargs - ) - fig.patch.set_facecolor(face_color) - - # Title - title = plot_kwargs["title"] - fontsize_title = plot_kwargs["fontsize_title"] - if title is not None: - fig.suptitle(title, fontsize=fontsize_title) - rawdata_axes = axx[0] - contrast_axes = axx[1] - rawdata_axes.set_frame_on(False) - contrast_axes.set_frame_on(False) - - redraw_axes_kwargs = { - "colors": ytick_color, - "facecolors": ytick_color, - "lw": 1, - "zorder": 10, - "clip_on": False, - } - - swarm_ylim = plot_kwargs["swarm_ylim"] - - if swarm_ylim is not None: - rawdata_axes.set_ylim(swarm_ylim) - one_sankey = ( False if is_paired is not None else None ) # Flag to indicate if only one sankey is plotted. two_col_sankey = ( - True if proportional and not one_sankey and sankey and not flow else False + True if proportional and not one_sankey and sankey_kwargs["sankey"] and not sankey_kwargs["flow"] else False ) - if show_pairs: - # Determine temp_idx based on is_paired and proportional conditions - if is_paired == "baseline": - idx_pairs = [ - (control, test) - for i in idx - for control, test in zip([i[0]] * (len(i) - 1), i[1:]) - ] - temp_idx = idx if not proportional else idx_pairs - else: - idx_pairs = [ - (control, test) for i in idx for control, test in zip(i[:-1], i[1:]) - ] - temp_idx = idx if not proportional else idx_pairs - - # Determine temp_all_plot_groups based on proportional condition - plot_groups = [item for i in temp_idx for item in i] - temp_all_plot_groups = all_plot_groups if not proportional else plot_groups + # Extract Color palette + (color_col, bootstraps_color_by_group, n_groups, filled, + swarm_colors, plot_palette_raw, bar_color, + plot_palette_bar, plot_palette_contrast, plot_palette_sankey) = get_color_palette( + plot_kwargs=plot_kwargs, + plot_data=plot_data, + xvar=xvar, + show_pairs=show_pairs, + idx=idx + ) + # Initialise the figure. + fig, rawdata_axes, contrast_axes, swarm_ylim = initialize_fig( + plot_kwargs=plot_kwargs, + dabest_obj=dabest_obj, + show_delta2=show_delta2, + show_mini_meta=show_mini_meta, + is_paired=is_paired, + show_pairs=show_pairs, + proportional=proportional, + float_contrast=float_contrast, + ) + + # Plotting the rawdata. + if show_pairs: + temp_idx, temp_all_plot_groups = get_plot_groups( + is_paired=is_paired, + idx=idx, + proportional=proportional, + all_plot_groups=all_plot_groups + ) if not proportional: # Plot the raw data as a slopegraph. - # Pivot the long (melted) data. - if color_col is None: - pivot_values = [yvar] - else: - pivot_values = [yvar, color_col] - pivoted_plot_data = pd.pivot( - data=plot_data, - index=dabest_obj.id_col, - columns=xvar, - values=pivot_values, - ) - x_start = 0 - for ii, current_tuple in enumerate(temp_idx): - current_pair = pivoted_plot_data.loc[ - :, pd.MultiIndex.from_product([pivot_values, current_tuple]) - ].dropna() - grp_count = len(current_tuple) - # Iterate through the data for the current tuple. - for ID, observation in current_pair.iterrows(): - x_points = [t for t in range(x_start, x_start + grp_count)] - y_points = observation[yvar].tolist() - - if color_col is None: - slopegraph_kwargs["color"] = ytick_color - else: - color_key = observation[color_col][0] - if isinstance(color_key, (str, np.int64, np.float64)): - slopegraph_kwargs["color"] = plot_palette_raw[color_key] - slopegraph_kwargs["label"] = color_key - - rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs) - - x_start = x_start + grp_count - - ##################### DELTA PTS ON CONTRAST PLOT WIP - - contrast_show_deltas = plot_kwargs["contrast_show_deltas"] - - if is_paired is None: - contrast_show_deltas = False - - if contrast_show_deltas: - delta_plot_data_temp = plot_data.copy() - delta_id_col = dabest_obj.id_col - if color_col is not None: - plot_palette_deltapts = plot_palette_raw - delta_plot_data = delta_plot_data_temp[ - [xvar, yvar, delta_id_col, color_col] - ] - deltapts_args = { - "marker": "^", - "alpha": 0.5, - } - - else: - plot_palette_deltapts = "k" - delta_plot_data = delta_plot_data_temp[[xvar, yvar, delta_id_col]] - deltapts_args = {"marker": "^", "alpha": 0.5} - - final_deltas = pd.DataFrame() - for i in idx: - for j in i: - if i.index(j) != 0: - temp_df_exp = delta_plot_data[ - delta_plot_data[xvar].str.contains(j) - ].reset_index(drop=True) - if is_paired == "baseline": - temp_df_cont = delta_plot_data[ - delta_plot_data[xvar].str.contains(i[0]) - ].reset_index(drop=True) - elif is_paired == "sequential": - temp_df_cont = delta_plot_data[ - delta_plot_data[xvar].str.contains( - i[i.index(j) - 1] - ) - ].reset_index(drop=True) - delta_df = temp_df_exp.copy() - delta_df[yvar] = temp_df_exp[yvar] - temp_df_cont[yvar] - final_deltas = pd.concat([final_deltas, delta_df]) - - # swarmplot() plots swarms based on current size of ax - # Therefore, since the ax size for Gardner-Altman plot changes later on, there has to be decreased jitter - # TODO: to make jitter value more accurate and not just a hardcoded eyeball value - if float_contrast: - jitter = 0.6 - else: - jitter = 1 - - # Plot the raw data as a swarmplot. - deltapts_plot = swarmplot( - data=final_deltas, - x=xvar, - y=yvar, - ax=contrast_axes, - order=None, - hue=color_col, - palette=plot_palette_deltapts, - zorder=2, - size=3, - side="right", - jitter=jitter, - is_drop_gutter=True, - gutter_limit=1, - **deltapts_args + slopegraph_plotter( + dabest_obj=dabest_obj, + plot_data=plot_data, + xvar=xvar, + yvar=yvar, + color_col=color_col, + plot_palette_raw=plot_palette_raw, + slopegraph_kwargs=slopegraph_kwargs, + rawdata_axes=rawdata_axes, + ytick_color=ytick_color, + temp_idx=temp_idx ) - contrast_axes.legend().set_visible(False) - ##################### DELTA PTS ON CONTRAST PLOT END + # DELTA PTS ON CONTRAST PLOT WIP + show_delta_dots = plot_kwargs["delta_dot"] + if show_delta_dots and is_paired is not None: + DeltaDotsPlotter( + plot_data=plot_data, + contrast_axes=contrast_axes, + delta_id_col=dabest_obj.id_col, + idx=idx, + xvar=xvar, + yvar=yvar, + is_paired=is_paired, + color_col=color_col, + float_contrast=float_contrast, + plot_palette_raw=plot_palette_raw, + delta_dot_kwargs=delta_dot_kwargs + ) # Set the tick labels, because the slopegraph plotting doesn't. rawdata_axes.set_xticks(np.arange(0, len(temp_all_plot_groups))) @@ -582,1013 +204,378 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs): else: # Plot the raw data as a set of Sankey Diagrams aligned like barplot. - group_summaries = plot_kwargs["group_summaries"] - if group_summaries is None: - group_summaries = "mean_sd" - err_color = plot_kwargs["err_color"] - if err_color is None: - err_color = "black" - - if show_pairs: - sankey_control_group = [] - sankey_test_group = [] - # Design for Sankey Flow Diagram - sankey_idx = ( - [ - (control, test) - for i in idx - for control, test in zip(i[:], (i[1:] + (i[0],))) - ] - if flow - else temp_idx - ) - for i in sankey_idx: - sankey_control_group.append(i[0]) - sankey_test_group.append(i[1]) - - if len(temp_all_plot_groups) == 2: - one_sankey = True - sankey_control_group.pop() - sankey_test_group.pop() # Remove the last element from two lists - - # two_col_sankey = True if proportional == True and one_sankey == False and sankey == True and flow == False else False - - # Replace the paired proportional plot with sankey diagram - sankeyplot = sankeydiag( - plot_data, - xvar=xvar, - yvar=yvar, - left_idx=sankey_control_group, - right_idx=sankey_test_group, - palette=plot_palette_sankey, - ax=rawdata_axes, - one_sankey=one_sankey, - **sankey_kwargs - ) - + sankey_control_group, sankey_test_group = sankeydiag( + plot_data, + xvar=xvar, + yvar=yvar, + temp_all_plot_groups=temp_all_plot_groups, + idx=idx, + temp_idx=temp_idx, + palette=plot_palette_sankey, + ax=rawdata_axes, + **sankey_kwargs + ) else: if not proportional: # Plot the raw data as a swarmplot. asymmetric_side = ( - plot_kwargs["swarm_side"] if plot_kwargs["swarm_side"] is not None else "right" + plot_kwargs["swarm_side"] + if plot_kwargs["swarm_side"] is not None + else "right" ) # Default asymmetric side is right # swarmplot() plots swarms based on current size of ax # Therefore, since the ax size for mini_meta and show_delta changes later on, there has to be increased jitter - # TODO: to make jitter value more accurate and not just a hardcoded eyeball value - if show_mini_meta: - jitter = 1.25 - elif show_delta2: - jitter = 1.4 - else: - jitter = 1 - - if color_col is None: # Determine the use of hue - rawdata_plot = swarmplot( - data=plot_data, - x=xvar, - y=yvar, - ax=rawdata_axes, - order=all_plot_groups, - hue=xvar, - palette=plot_palette_raw, - zorder=1, - side=asymmetric_side, - jitter=jitter, - is_drop_gutter=True, - gutter_limit=0.45, - **swarmplot_kwargs - ) + rawdata_plot, swarm_legend_kwargs = swarmplot( + data=plot_data, + x=xvar, + y=yvar, + ax=rawdata_axes, + order=all_plot_groups, + # hue=xvar if color_col is None else color_col, + hue=color_col, + palette=plot_palette_raw, + zorder=1, + side=asymmetric_side, + jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value + filled=filled, + is_drop_gutter=True, + gutter_limit=0.45, + **swarmplot_kwargs + ) + if color_col is None: rawdata_plot.legend().set_visible(False) - else: - rawdata_plot = swarmplot( - data=plot_data, - x=xvar, - y=yvar, - ax=rawdata_axes, - order=all_plot_groups, - hue=color_col, - palette=plot_palette_raw, - zorder=1, - side=asymmetric_side, - jitter=jitter, - is_drop_gutter=True, - gutter_limit=0.45, - **swarmplot_kwargs - ) - else: - # Plot the raw data as a barplot. - bar1_df = pd.DataFrame( - {xvar: all_plot_groups, "proportion": np.ones(len(all_plot_groups))} - ) - bar1 = sns.barplot( - data=bar1_df, - x=xvar, - y="proportion", - ax=rawdata_axes, - order=all_plot_groups, - linewidth=2, - facecolor=(1, 1, 1, 0), - edgecolor=bar_color, - zorder=1, - ) - bar2 = sns.barplot( - data=plot_data, - x=xvar, - y=yvar, - ax=rawdata_axes, - order=all_plot_groups, - palette=plot_palette_bar, - zorder=1, - **barplot_kwargs - ) - # adjust the width of bars - bar_width = plot_kwargs["bar_width"] - for bar in bar1.patches: - x = bar.get_x() - width = bar.get_width() - centre = x + width / 2.0 - bar.set_x(centre - bar_width / 2.0) - bar.set_width(bar_width) - # Plot the gapped line summaries, if this is not a Cumming plot. - # Also, we will not plot gapped lines for paired plots. For now. - group_summaries = plot_kwargs["group_summaries"] - if group_summaries is None: - group_summaries = "mean_sd" - if group_summaries is not None and not proportional: - # Create list to gather xspans. - xspans = [] - line_colors = [] - for jj, c in enumerate(rawdata_axes.collections): - try: - if asymmetric_side == "right": - # currently offset is hardcoded with value of -0.2 - x_max_span = -0.2 - else: - _, x_max, _, _ = get_swarm_spans(c) - x_max_span = x_max - jj - xspans.append(x_max_span) - except TypeError: - # we have got a None, so skip and move on. - pass - - if bootstraps_color_by_group: - line_colors.append(plot_palette_raw[all_plot_groups[jj]]) - - # Break the loop since hue in Seaborn adds collections to axes and it will result in index out of range - if jj >= n_groups - 1 and color_col is None: - break - - if len(line_colors) != len(all_plot_groups): - line_colors = ytick_color - - error_bar( - plot_data, - x=xvar, - y=yvar, - # Hardcoded offset... - offset=xspans + np.array(plot_kwargs["group_summaries_offset"]), - line_color=line_colors, - gap_width_percent=1.5, - type=group_summaries, - ax=rawdata_axes, - method="gapped_lines", - **group_summary_kwargs - ) + else: + # Plot the raw data as a barplot. + barplotter( + xvar=xvar, + yvar=yvar, + all_plot_groups=all_plot_groups, + rawdata_axes=rawdata_axes, + plot_data=plot_data, + bar_color=bar_color, + plot_palette_bar=plot_palette_bar, + plot_kwargs=plot_kwargs, + barplot_kwargs=barplot_kwargs + ) - if group_summaries is not None and proportional: - err_color = plot_kwargs["err_color"] - if err_color is None: - err_color = "black" + # Plot the error bars. + if group_summaries is not None: + if proportional: + group_summaries_method = "proportional_error_bar" + group_summaries_offset = 0 + group_summaries_line_color = err_color + else: + # Create list to gather xspans. + xspans = [] + line_colors = [] + for jj, c in enumerate(rawdata_axes.collections): + try: + if asymmetric_side == "right": + # currently offset is hardcoded with value of -0.2 + x_max_span = -0.2 + else: + _, x_max, _, _ = get_swarm_spans(c) + x_max_span = x_max - jj + xspans.append(x_max_span) + except TypeError: + # we have got a None, so skip and move on. + pass + + if bootstraps_color_by_group: + line_colors.append(plot_palette_raw[all_plot_groups[jj]]) + + # Break the loop since hue in Seaborn adds collections to axes and it will result in index out of range + if jj >= n_groups - 1 and color_col is None: + break + + if len(line_colors) != len(all_plot_groups): + line_colors = ytick_color + + # hue in swarmplot would add collections to axes which will result in len(xspans) = len(all_plot_groups) + len(unique groups in hue) + if len(xspans) > len(all_plot_groups): + xspans = xspans[:len(all_plot_groups)] + + group_summaries_method = "gapped_lines" + group_summaries_offset = xspans + np.array(plot_kwargs["group_summaries_offset"]) + group_summaries_line_color = line_colors + + # Plot error_bar( plot_data, x=xvar, y=yvar, - offset=0, - line_color=err_color, + offset=group_summaries_offset, + line_color=group_summaries_line_color, gap_width_percent=1.5, type=group_summaries, ax=rawdata_axes, - method="proportional_error_bar", + method=group_summaries_method, **group_summary_kwargs - ) + ) # Add the counts to the rawdata axes xticks. - counts = plot_data.groupby(xvar).count()[yvar] - ticks_with_counts = [] - ticks_loc = rawdata_axes.get_xticks() - rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc)) - for xticklab in rawdata_axes.xaxis.get_ticklabels(): - t = xticklab.get_text() - if t.rfind("\n") != -1: - te = t[t.rfind("\n") + len("\n") :] - N = str(counts.loc[te]) - te = t - else: - te = t - N = str(counts.loc[te]) - - ticks_with_counts.append("{}\nN = {}".format(te, N)) - - if plot_kwargs["fontsize_rawxlabel"] is not None: - fontsize_rawxlabel = plot_kwargs["fontsize_rawxlabel"] - rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel) - - # Save the handles and labels for the legend. - handles, labels = rawdata_axes.get_legend_handles_labels() - legend_labels = [l for l in labels] - legend_handles = [h for h in handles] - if bootstraps_color_by_group is False: - rawdata_axes.legend().set_visible(False) + add_counts_to_ticks( + plot_data=plot_data, + xvar=xvar, + yvar=yvar, + rawdata_axes=rawdata_axes, + plot_kwargs=plot_kwargs + ) - # Enforce the xtick of rawdata_axes to be 0 and 1 after drawing only one sankey + # Enforce the xtick of rawdata_axes to be 0 and 1 after drawing only one sankey ----> Redundant code if one_sankey: rawdata_axes.set_xticks([0, 1]) # Plot effect sizes and bootstraps. - # Take note of where the `control` groups are. - if is_paired == "baseline" and show_pairs: - if two_col_sankey: - ticks_to_skip = [] - ticks_to_plot = np.arange(0, len(temp_all_plot_groups) / 2).tolist() - ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist() - ticks_to_start_twocol_sankey.pop() - ticks_to_start_twocol_sankey.insert(0, 0) - else: - # ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist() - # ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist() - ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist() - ticks_to_skip.insert(0, 0) - # Then obtain the ticks where we have to plot the effect sizes. - ticks_to_plot = [ - t for t in range(0, len(all_plot_groups)) if t not in ticks_to_skip - ] - ticks_to_skip_contrast = np.cumsum([(len(t)) for t in idx])[:-1].tolist() - ticks_to_skip_contrast.insert(0, 0) - else: - if two_col_sankey: - ticks_to_skip = [len(sankey_control_group)] - # Then obtain the ticks where we have to plot the effect sizes. - ticks_to_plot = [ - t for t in range(0, len(temp_idx)) if t not in ticks_to_skip - ] - ticks_to_skip = [] - ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist() - ticks_to_start_twocol_sankey.pop() - ticks_to_start_twocol_sankey.insert(0, 0) - else: - ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist() - ticks_to_skip.insert(0, 0) - # Then obtain the ticks where we have to plot the effect sizes. - ticks_to_plot = [ - t for t in range(0, len(all_plot_groups)) if t not in ticks_to_skip - ] + plot_groups = temp_all_plot_groups if (is_paired == "baseline" and show_pairs and two_col_sankey) else temp_idx if (two_col_sankey) else all_plot_groups + + (ticks_to_skip, ticks_to_plot, + ticks_to_skip_contrast, ticks_to_start_twocol_sankey) = extract_contrast_plotting_ticks( + is_paired=is_paired, + show_pairs=show_pairs, + two_col_sankey=two_col_sankey, + plot_groups=plot_groups, + idx=idx, + sankey_control_group=sankey_control_group if two_col_sankey else None, + ) # Plot the bootstraps, then the effect sizes and CIs. es_marker_size = plot_kwargs["es_marker_size"] halfviolin_alpha = plot_kwargs["halfviolin_alpha"] - ci_type = plot_kwargs["ci_type"] results = effectsize_df.results - contrast_xtick_labels = [] - - for j, tick in enumerate(ticks_to_plot): - current_group = results.test[j] - current_control = results.control[j] - current_bootstrap = results.bootstraps[j] - current_effsize = results.difference[j] - if ci_type == "bca": - current_ci_low = results.bca_low[j] - current_ci_high = results.bca_high[j] - else: - current_ci_low = results.pct_low[j] - current_ci_high = results.pct_high[j] - - # Create the violinplot. - # New in v0.2.6: drop negative infinities before plotting. - v = contrast_axes.violinplot( - current_bootstrap[~np.isinf(current_bootstrap)], - positions=[tick], - **violinplot_kwargs - ) - # Turn the violinplot into half, and color it the same as the swarmplot. - # Do this only if the color column is not specified. - # Ideally, the alpha (transparency) fo the violin plot should be - # less than one so the effect size and CIs are visible. - if bootstraps_color_by_group: - fc = plot_palette_contrast[current_group] - else: - fc = "grey" - - halfviolin(v, fill_color=fc, alpha=halfviolin_alpha) - - # Plot the effect size. - contrast_axes.plot( - [tick], - current_effsize, - marker="o", - color=ytick_color, - markersize=es_marker_size, - ) - - ################## SHOW ES ON CONTRAST PLOT WIP - - contrast_show_es = plot_kwargs["contrast_show_es"] - es_sf = plot_kwargs["es_sf"] - es_fontsize = plot_kwargs["es_fontsize"] - - if gridkey_show_es: - contrast_show_es = False - - effsize_for_print = current_effsize - - printed_es = np.format_float_positional( - effsize_for_print, precision=es_sf, sign=True, trim="k", min_digits=es_sf - ) - if contrast_show_es: - if effsize_for_print < 0: - textoffset = 10 - else: - textoffset = 15 - contrast_axes.annotate( - text=printed_es, - xy=(tick, effsize_for_print), - xytext=( - -textoffset - len(printed_es) * es_fontsize / 2, - -es_fontsize / 2, - ), - textcoords="offset points", - **{"fontsize": es_fontsize} - ) - ################## SHOW ES ON CONTRAST PLOT END - - # Plot the confidence interval. - contrast_axes.plot( - [tick, tick], - [current_ci_low, current_ci_high], - linestyle="-", - color=ytick_color, - linewidth=group_summary_kwargs["lw"], - ) - - contrast_xtick_labels.append( - "{}\nminus\n{}".format(current_group, current_control) - ) + (current_group, current_control, + current_effsize, contrast_xtick_labels) = effect_size_curve_plotter( + ticks_to_plot=ticks_to_plot, + results=results, + ci_type=ci_type, + contrast_axes=contrast_axes, + violinplot_kwargs=violinplot_kwargs, + halfviolin_alpha=halfviolin_alpha, + ytick_color=ytick_color, + es_marker_size=es_marker_size, + group_summary_kwargs=group_summary_kwargs, + bootstraps_color_by_group=bootstraps_color_by_group, + plot_palette_contrast=plot_palette_contrast, + ) # Plot mini-meta violin if show_mini_meta or show_delta2: - if show_mini_meta: - mini_meta_delta = effectsize_df.mini_meta_delta - data = mini_meta_delta.bootstraps_weighted_delta - difference = mini_meta_delta.difference - if ci_type == "bca": - ci_low = mini_meta_delta.bca_low - ci_high = mini_meta_delta.bca_high - else: - ci_low = mini_meta_delta.pct_low - ci_high = mini_meta_delta.pct_high - else: - delta_delta = effectsize_df.delta_delta - data = delta_delta.bootstraps_delta_delta - difference = delta_delta.difference - if ci_type == "bca": - ci_low = delta_delta.bca_low - ci_high = delta_delta.bca_high - else: - ci_low = delta_delta.pct_low - ci_high = delta_delta.pct_high - # Create the violinplot. - # New in v0.2.6: drop negative infinities before plotting. - position = max(rawdata_axes.get_xticks()) + 2 - v = contrast_axes.violinplot( - data[~np.isinf(data)], positions=[position], **violinplot_kwargs - ) - - fc = "grey" - - halfviolin(v, fill_color=fc, alpha=halfviolin_alpha) - - # Plot the effect size. - contrast_axes.plot( - [position], - difference, - marker="o", - color=ytick_color, - markersize=es_marker_size, - ) - # Plot the confidence interval. - contrast_axes.plot( - [position, position], - [ci_low, ci_high], - linestyle="-", - color=ytick_color, - linewidth=group_summary_kwargs["lw"], - ) - if show_mini_meta: - contrast_xtick_labels.extend(["", "Weighted delta"]) - elif effect_size == "delta_g": - contrast_xtick_labels.extend(["", "deltas' g"]) - else: - contrast_xtick_labels.extend(["", "delta-delta"]) + contrast_xtick_labels = plot_minimeta_or_deltadelta_violins( + show_mini_meta=show_mini_meta, + effectsize_df=effectsize_df, + ci_type=ci_type, + rawdata_axes=rawdata_axes, + contrast_axes=contrast_axes, + violinplot_kwargs=violinplot_kwargs, + halfviolin_alpha=halfviolin_alpha, + ytick_color=ytick_color, + es_marker_size=es_marker_size, + group_summary_kwargs=group_summary_kwargs, + contrast_xtick_labels=contrast_xtick_labels, + effect_size=effect_size + ) # Make sure the contrast_axes x-lims match the rawdata_axes xlims, # and add an extra violinplot tick for delta-delta plot. - if show_delta2 is False and show_mini_meta is False: - contrast_axes.set_xticks(rawdata_axes.get_xticks()) - else: - temp = rawdata_axes.get_xticks() - temp = np.append(temp, [max(temp) + 1, max(temp) + 2]) - contrast_axes.set_xticks(temp) - - if show_pairs: - max_x = contrast_axes.get_xlim()[1] - rawdata_axes.set_xlim(-0.375, max_x) - - if float_contrast: - contrast_axes.set_xlim(0.5, 1.5) - elif show_delta2 or show_mini_meta: - # Increase the xlim of raw data by 2 - temp = rawdata_axes.get_xlim() - if show_pairs: - rawdata_axes.set_xlim(temp[0], temp[1] + 0.25) - else: - rawdata_axes.set_xlim(temp[0], temp[1] + 2) - contrast_axes.set_xlim(rawdata_axes.get_xlim()) - else: - contrast_axes.set_xlim(rawdata_axes.get_xlim()) - - # Properly label the contrast ticks. - for t in ticks_to_skip: - contrast_xtick_labels.insert(t, "") - - if plot_kwargs["fontsize_contrastxlabel"] is not None: - fontsize_contrastxlabel = plot_kwargs["fontsize_contrastxlabel"] - - contrast_axes.set_xticklabels( - contrast_xtick_labels, fontsize=fontsize_contrastxlabel - ) + set_xaxis_ticks_and_lims( + show_delta2=show_delta2, + show_mini_meta=show_mini_meta, + rawdata_axes=rawdata_axes, + contrast_axes=contrast_axes, + show_pairs=show_pairs, + float_contrast=float_contrast, + ticks_to_skip=ticks_to_skip, + contrast_xtick_labels=contrast_xtick_labels, + plot_kwargs=plot_kwargs, + ) + # Legend + handles, labels = rawdata_axes.get_legend_handles_labels() + legend_labels = [l for l in labels] + legend_handles = [h for h in handles] if bootstraps_color_by_group is False: - legend_labels_unique = np.unique(legend_labels) - unique_idx = np.unique(legend_labels, return_index=True)[1] - legend_handles_unique = ( - pd.Series(legend_handles, dtype="object").loc[unique_idx] - ).tolist() - - if len(legend_handles_unique) > 0: - if float_contrast: - axes_with_legend = contrast_axes - if show_pairs: - bta = (1.75, 1.02) - else: - bta = (1.5, 1.02) - else: - axes_with_legend = rawdata_axes - if show_pairs: - bta = (1.02, 1.0) - else: - bta = (1.0, 1.0) - leg = axes_with_legend.legend( - legend_handles_unique, - legend_labels_unique, - bbox_to_anchor=bta, - **legend_kwargs + rawdata_axes.legend().set_visible(False) + show_legend( + legend_labels=legend_labels, + legend_handles=legend_handles, + rawdata_axes=rawdata_axes, + contrast_axes=contrast_axes, + float_contrast=float_contrast, + show_pairs=show_pairs, + legend_kwargs=legend_kwargs ) - if show_pairs: - for line in leg.get_lines(): - line.set_linewidth(3.0) + # Add legend for swarmplot + if not show_pairs and not proportional and color_col is not None and not show_delta2: + if len(np.unique(swarm_legend_kwargs['index'])) > 1: + legend_elements = [] + for color, label in zip(swarm_legend_kwargs['colors'], swarm_legend_kwargs['labels']): + legend_elements.append(Line2D([0], [0], marker='o', color='w', label=label, + markerfacecolor=color, markersize=10)) + rawdata_axes.legend(handles=legend_elements, frameon=False) + + # Plot aesthetic adjustments. og_ylim_raw = rawdata_axes.get_ylim() og_xlim_raw = rawdata_axes.get_xlim() if float_contrast: # For Gardner-Altman plots only. - - # Normalize ylims and despine the floating contrast axes. - # Check that the effect size is within the swarm ylims. - if effect_size_type in ["mean_diff", "cohens_d", "hedges_g", "cohens_h"]: - control_group_summary = ( - plot_data.groupby(xvar) - .mean(numeric_only=True) - .loc[current_control, yvar] - ) - test_group_summary = ( - plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar] - ) - elif effect_size_type == "median_diff": - control_group_summary = ( - plot_data.groupby(xvar).median().loc[current_control, yvar] - ) - test_group_summary = ( - plot_data.groupby(xvar).median().loc[current_group, yvar] - ) - - if swarm_ylim is None: - swarm_ylim = rawdata_axes.get_ylim() - - _, contrast_xlim_max = contrast_axes.get_xlim() - - difference = float(results.difference[0]) - - if effect_size_type in ["mean_diff", "median_diff"]: - # Align 0 of contrast_axes to reference group mean of rawdata_axes. - # If the effect size is positive, shift the contrast axis up. - rawdata_ylims = np.array(rawdata_axes.get_ylim()) - if current_effsize > 0: - rightmin, rightmax = rawdata_ylims - current_effsize - # If the effect size is negative, shift the contrast axis down. - elif current_effsize < 0: - rightmin, rightmax = rawdata_ylims + current_effsize - else: - rightmin, rightmax = rawdata_ylims - - contrast_axes.set_ylim(rightmin, rightmax) - - og_ylim_contrast = rawdata_axes.get_ylim() - np.array(control_group_summary) - - contrast_axes.set_ylim(og_ylim_contrast) - contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max) - - elif effect_size_type in ["cohens_d", "hedges_g", "cohens_h"]: - if is_paired: - which_std = 1 - else: - which_std = 0 - temp_control = plot_data[plot_data[xvar] == current_control][yvar] - temp_test = plot_data[plot_data[xvar] == current_group][yvar] - - stds = _compute_standardizers(temp_control, temp_test) - if is_paired: - pooled_sd = stds[1] - else: - pooled_sd = stds[0] - - if effect_size_type == "hedges_g": - gby_count = plot_data.groupby(xvar).count() - len_control = gby_count.loc[current_control, yvar] - len_test = gby_count.loc[current_group, yvar] - - hg_correction_factor = _compute_hedges_correction_factor( - len_control, len_test - ) - - ylim_scale_factor = pooled_sd / hg_correction_factor - - elif effect_size_type == "cohens_h": - ylim_scale_factor = ( - np.mean(temp_test) - np.mean(temp_control) - ) / difference - - else: - ylim_scale_factor = pooled_sd - - scaled_ylim = ( - (rawdata_axes.get_ylim() - control_group_summary) / ylim_scale_factor - ).tolist() - - contrast_axes.set_ylim(scaled_ylim) - og_ylim_contrast = scaled_ylim - - contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max) - - if one_sankey is None: - # Draw summary lines for control and test groups.. - for jj, axx in enumerate([rawdata_axes, contrast_axes]): - # Draw effect size line. - if jj == 0: - ref = control_group_summary - diff = test_group_summary - effsize_line_start = 1 - - elif jj == 1: - ref = 0 - diff = ref + difference - effsize_line_start = contrast_xlim_max - 1.1 - - xlimlow, xlimhigh = axx.get_xlim() - - # Draw reference line. - axx.hlines( - ref, # y-coordinates - 0, - xlimhigh, # x-coordinates, start and end. - **reflines_kwargs - ) - - # Draw effect size line. - axx.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs) - else: - ref = 0 - diff = ref + difference - effsize_line_start = contrast_xlim_max - 0.9 - xlimlow, xlimhigh = contrast_axes.get_xlim() - # Draw reference line. - contrast_axes.hlines( - ref, # y-coordinates - effsize_line_start, - xlimhigh, # x-coordinates, start and end. - **reflines_kwargs - ) - - # Draw effect size line. - contrast_axes.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs) - rawdata_axes.set_xlim(og_xlim_raw) # to align the axis - # Despine appropriately. - sns.despine(ax=rawdata_axes, bottom=True) - sns.despine(ax=contrast_axes, left=True, right=False) - - # Insert break between the rawdata axes and the contrast axes - # by re-drawing the x-spine. - rawdata_axes.hlines( - og_ylim_raw[0], # yindex - rawdata_axes.get_xlim()[0], - 1.3, # xmin, xmax - **redraw_axes_kwargs - ) - rawdata_axes.set_ylim(og_ylim_raw) - - contrast_axes.hlines( - contrast_axes.get_ylim()[0], - contrast_xlim_max - 0.8, - contrast_xlim_max, - **redraw_axes_kwargs - ) + Gardner_Altman_Plot_Aesthetic_Adjustments( + effect_size_type=effect_size_type, + plot_data=plot_data, + xvar=xvar, + yvar=yvar, + current_control=current_control, + current_group=current_group, + rawdata_axes=rawdata_axes, + contrast_axes=contrast_axes, + results=results, + current_effsize=current_effsize, + is_paired=is_paired, + one_sankey=one_sankey, + reflines_kwargs=reflines_kwargs, + redraw_axes_kwargs=redraw_axes_kwargs, + swarm_ylim=swarm_ylim, + og_xlim_raw=og_xlim_raw, + og_ylim_raw=og_ylim_raw, + ) else: # For Cumming Plots only. - - # Set custom contrast_ylim, if it was specified. - if plot_kwargs["contrast_ylim"] is not None or ( - plot_kwargs["delta2_ylim"] is not None and show_delta2 - ): - if plot_kwargs["contrast_ylim"] is not None: - custom_contrast_ylim = plot_kwargs["contrast_ylim"] - if plot_kwargs["delta2_ylim"] is not None and show_delta2: - custom_delta2_ylim = plot_kwargs["delta2_ylim"] - if custom_contrast_ylim != custom_delta2_ylim: - err1 = "Please check if `contrast_ylim` and `delta2_ylim` are assigned" - err2 = "with same values." - raise ValueError(err1 + err2) - else: - custom_delta2_ylim = plot_kwargs["delta2_ylim"] - custom_contrast_ylim = custom_delta2_ylim - - if len(custom_contrast_ylim) != 2: - err1 = "Please check `contrast_ylim` consists of " - err2 = "exactly two numbers." - raise ValueError(err1 + err2) - - if effect_size_type == "cliffs_delta": - # Ensure the ylims for a cliffs_delta plot never exceed [-1, 1]. - l = plot_kwargs["contrast_ylim"][0] - h = plot_kwargs["contrast_ylim"][1] - low = -1 if l < -1 else l - high = 1 if h > 1 else h - contrast_axes.set_ylim(low, high) - else: - contrast_axes.set_ylim(custom_contrast_ylim) - - # If 0 lies within the ylim of the contrast axes, - # draw a zero reference line. - contrast_axes_ylim = contrast_axes.get_ylim() - if contrast_axes_ylim[0] < contrast_axes_ylim[1]: - contrast_ylim_low, contrast_ylim_high = contrast_axes_ylim - else: - contrast_ylim_high, contrast_ylim_low = contrast_axes_ylim - if contrast_ylim_low < 0 < contrast_ylim_high: - contrast_axes.axhline(y=0, **reflines_kwargs) - - if is_paired == "baseline" and show_pairs: - if two_col_sankey: - rightend_ticks_raw = np.array([len(i) - 2 for i in idx]) + np.array( - ticks_to_start_twocol_sankey - ) - elif proportional and is_paired is not None: - rightend_ticks_raw = np.array([len(i) - 1 for i in idx]) + np.array( - ticks_to_skip - ) - else: - rightend_ticks_raw = np.array( - [len(i) - 1 for i in temp_idx] - ) + np.array(ticks_to_skip) - for ax in [rawdata_axes]: - sns.despine(ax=ax, bottom=True) - - ylim = ax.get_ylim() - xlim = ax.get_xlim() - redraw_axes_kwargs["y"] = ylim[0] - - if two_col_sankey: - for k, start_tick in enumerate(ticks_to_start_twocol_sankey): - end_tick = rightend_ticks_raw[k] - ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) - else: - for k, start_tick in enumerate(ticks_to_skip): - end_tick = rightend_ticks_raw[k] - ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) - ax.set_ylim(ylim) - del redraw_axes_kwargs["y"] - - if not proportional: - temp_length = [(len(i) - 1) for i in idx] - else: - temp_length = [(len(i) - 1) * 2 - 1 for i in idx] - if two_col_sankey: - rightend_ticks_contrast = np.array( - [len(i) - 2 for i in idx] - ) + np.array(ticks_to_start_twocol_sankey) - elif proportional and is_paired is not None: - rightend_ticks_contrast = np.array( - [len(i) - 1 for i in idx] - ) + np.array(ticks_to_skip) - else: - rightend_ticks_contrast = np.array(temp_length) + np.array( - ticks_to_skip_contrast - ) - for ax in [contrast_axes]: - sns.despine(ax=ax, bottom=True) - - ylim = ax.get_ylim() - xlim = ax.get_xlim() - redraw_axes_kwargs["y"] = ylim[0] - - if two_col_sankey: - for k, start_tick in enumerate(ticks_to_start_twocol_sankey): - end_tick = rightend_ticks_contrast[k] - ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) - else: - for k, start_tick in enumerate(ticks_to_skip_contrast): - end_tick = rightend_ticks_contrast[k] - ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) - - ax.set_ylim(ylim) - del redraw_axes_kwargs["y"] - else: - # Compute the end of each x-axes line. - if two_col_sankey: - rightend_ticks = np.array([len(i) - 2 for i in idx]) + np.array( - ticks_to_start_twocol_sankey - ) - else: - rightend_ticks = np.array([len(i) - 1 for i in idx]) + np.array( - ticks_to_skip - ) - - for ax in [rawdata_axes, contrast_axes]: - sns.despine(ax=ax, bottom=True) - - ylim = ax.get_ylim() - xlim = ax.get_xlim() - redraw_axes_kwargs["y"] = ylim[0] - - if two_col_sankey: - for k, start_tick in enumerate(ticks_to_start_twocol_sankey): - end_tick = rightend_ticks[k] - ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) - else: - for k, start_tick in enumerate(ticks_to_skip): - end_tick = rightend_ticks[k] - ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs) - - ax.set_ylim(ylim) - del redraw_axes_kwargs["y"] - - if show_delta2 or show_mini_meta: - ylim = contrast_axes.get_ylim() - redraw_axes_kwargs["y"] = ylim[0] - x_ticks = contrast_axes.get_xticks() - contrast_axes.hlines(xmin=x_ticks[-2], xmax=x_ticks[-1], **redraw_axes_kwargs) - del redraw_axes_kwargs["y"] - - # Set raw axes y-label. - swarm_label = plot_kwargs["swarm_label"] - if swarm_label is None and yvar is None: - swarm_label = "value" - elif swarm_label is None and yvar is not None: - swarm_label = yvar - - bar_label = plot_kwargs["bar_label"] - if bar_label is None and effect_size_type != "cohens_h": - bar_label = "proportion of success" - elif bar_label is None and effect_size_type == "cohens_h": - bar_label = "value" - - # Place contrast axes y-label. - contrast_label_dict = { - "mean_diff": "mean difference", - "median_diff": "median difference", - "cohens_d": "Cohen's d", - "hedges_g": "Hedges' g", - "cliffs_delta": "Cliff's delta", - "cohens_h": "Cohen's h", - "delta_g": "mean difference", - } - - if proportional and effect_size_type != "cohens_h": - default_contrast_label = "proportion difference" - elif effect_size_type == "delta_g": - default_contrast_label = "Hedges' g" - else: - default_contrast_label = contrast_label_dict[effectsize_df.effect_size] - - if plot_kwargs["contrast_label"] is None: - if is_paired: - contrast_label = "paired\n{}".format(default_contrast_label) - else: - contrast_label = default_contrast_label - contrast_label = contrast_label.capitalize() - else: - contrast_label = plot_kwargs["contrast_label"] - - if plot_kwargs["fontsize_rawylabel"] is not None: - fontsize_rawylabel = plot_kwargs["fontsize_rawylabel"] - if plot_kwargs["fontsize_contrastylabel"] is not None: - fontsize_contrastylabel = plot_kwargs["fontsize_contrastylabel"] - if plot_kwargs["fontsize_delta2label"] is not None: - fontsize_delta2label = plot_kwargs["fontsize_delta2label"] - - contrast_axes.set_ylabel(contrast_label, fontsize=fontsize_contrastylabel) - if float_contrast: - contrast_axes.yaxis.set_label_position("right") - - # Set the rawdata axes labels appropriately - if not proportional: - rawdata_axes.set_ylabel(swarm_label, fontsize=fontsize_rawylabel) - else: - rawdata_axes.set_ylabel(bar_label, fontsize=fontsize_rawylabel) - rawdata_axes.set_xlabel("") - - # Because we turned the axes frame off, we also need to draw back - # the y-spine for both axes. - if not float_contrast: - rawdata_axes.set_xlim(contrast_axes.get_xlim()) - og_xlim_raw = rawdata_axes.get_xlim() - rawdata_axes.vlines( - og_xlim_raw[0], og_ylim_raw[0], og_ylim_raw[1], **redraw_axes_kwargs - ) - - og_xlim_contrast = contrast_axes.get_xlim() - - if float_contrast: - xpos = og_xlim_contrast[1] - else: - xpos = og_xlim_contrast[0] - - og_ylim_contrast = contrast_axes.get_ylim() - contrast_axes.vlines( - xpos, og_ylim_contrast[0], og_ylim_contrast[1], **redraw_axes_kwargs - ) - - if show_delta2: - if plot_kwargs["delta2_label"] is not None: - delta2_label = plot_kwargs["delta2_label"] - elif effect_size == "mean_diff": - delta2_label = "delta - delta" - else: - delta2_label = "deltas' g" - delta2_axes = contrast_axes.twinx() - delta2_axes.set_frame_on(False) - delta2_axes.set_ylabel(delta2_label, fontsize=fontsize_delta2label) - og_xlim_delta = contrast_axes.get_xlim() - og_ylim_delta = contrast_axes.get_ylim() - delta2_axes.set_ylim(og_ylim_delta) - delta2_axes.vlines( - og_xlim_delta[1], og_ylim_delta[0], og_ylim_delta[1], **redraw_axes_kwargs - ) - - ################################################### GRIDKEY MAIN CODE WIP - + Cumming_Plot_Aesthetic_Adjustments( + plot_kwargs=plot_kwargs, + show_delta2=show_delta2, + effect_size_type=effect_size_type, + contrast_axes=contrast_axes, + reflines_kwargs=reflines_kwargs, + is_paired=is_paired, + show_pairs=show_pairs, + two_col_sankey=two_col_sankey, + idx=idx, + ticks_to_start_twocol_sankey=ticks_to_start_twocol_sankey, + proportional=proportional, + ticks_to_skip=ticks_to_skip, + temp_idx=temp_idx if is_paired == "baseline" and show_pairs else None, + rawdata_axes=rawdata_axes, + redraw_axes_kwargs=redraw_axes_kwargs, + ticks_to_skip_contrast=ticks_to_skip_contrast, + ) + + # General plotting changes + General_Plot_Aesthetic_Adjustments( + show_delta2=show_delta2, + show_mini_meta=show_mini_meta, + contrast_axes=contrast_axes, + redraw_axes_kwargs=redraw_axes_kwargs, + plot_kwargs=plot_kwargs, + yvar=yvar, + effect_size_type=effect_size_type, + proportional=proportional, + effectsize_df=effectsize_df, + is_paired=is_paired, + float_contrast=float_contrast, + rawdata_axes=rawdata_axes, + og_ylim_raw=og_ylim_raw, + effect_size=effect_size, + ) + + ################################################### GRIDKEY WIP # if gridkey_rows is None, skip everything here + gridkey_rows = plot_kwargs["gridkey_rows"] if gridkey_rows is not None: - # Raise error if there are more than 2 items in any idx and gridkey_merge_pairs is True and is_paired is not None - if gridkey_merge_pairs and is_paired is not None: - for i in idx: - if len(i) > 2: - warnings.warn( - "gridkey_merge_pairs=True only works if all idx in tuples have only two items. gridkey_merge_pairs has automatically been set to False" - ) - gridkey_merge_pairs = False - break - elif gridkey_merge_pairs and is_paired is None: - warnings.warn( - "gridkey_merge_pairs=True is only applicable for paired data." - ) - gridkey_merge_pairs = False - - # Checks for gridkey_merge_pairs and is_paired; if both are true, "merges" the gridkey per pair - if gridkey_merge_pairs and is_paired is not None: - groups_for_gridkey = [] - for i in idx: - groups_for_gridkey.append(i[1]) - else: - groups_for_gridkey = all_plot_groups - - # raise errors if gridkey_rows is not a list, or if the list is empty - if isinstance(gridkey_rows, list) is False: - raise TypeError("gridkey_rows must be a list.") - elif len(gridkey_rows) == 0: - warnings.warn("gridkey_rows is an empty list.") + grid_key_WIP( + is_paired=is_paired, + idx=idx, + all_plot_groups=all_plot_groups, + gridkey_rows=gridkey_rows, + rawdata_axes=rawdata_axes, + contrast_axes=contrast_axes, + plot_data=plot_data, + xvar=xvar, + yvar=yvar, + results=results, + show_delta2=show_delta2, + show_mini_meta=show_mini_meta, + float_contrast=float_contrast, + plot_kwargs=plot_kwargs, + ) - # raise Warning if an item in gridkey_rows is not contained in any idx - for i in gridkey_rows: - in_idx = 0 - for j in groups_for_gridkey: - if i in j: - in_idx += 1 - if in_idx == 0: - if is_paired is not None: - warnings.warn( - i - + " is not in any idx. Please check. Alternatively, merging gridkey pairs may not be suitable for your data; try passing gridkey_merge_pairs=False." + ################################################### Swarm & Contrast & Summary Bars & Delta text WIP + # Swarm bars WIP + swarm_bars = plot_kwargs["swarm_bars"] + if swarm_bars and not proportional: + swarm_bars_plotter( + plot_data=plot_data, + xvar=xvar, + yvar=yvar, + ax=rawdata_axes, + swarm_bars_kwargs=swarm_bars_kwargs, + color_col=color_col, + plot_palette_raw=plot_palette_raw, + is_paired=is_paired ) - else: - warnings.warn(i + " is not in any idx. Please check.") - - # Populate table: checks if idx for each column contains rowlabel name - # IF so, marks that element as present w black dot, or space if not present - table_cellcols = [] - for i in gridkey_rows: - thisrow = [] - for q in groups_for_gridkey: - if str(i) in q: - thisrow.append("\u25CF") - else: - thisrow.append("") - table_cellcols.append(thisrow) - - # Adds a row for Ns with the Ns values - if gridkey_show_Ns: - gridkey_rows.append("Ns") - list_of_Ns = [] - for i in groups_for_gridkey: - list_of_Ns.append(str(counts.loc[i])) - table_cellcols.append(list_of_Ns) - # Adds a row for effectsizes with effectsize values - if gridkey_show_es: - gridkey_rows.append("\u0394") - effsize_list = [] - results_list = results.test.to_list() - - # get the effect size, append + or -, 2 dec places - for i in enumerate(groups_for_gridkey): - if i[1] in results_list: - curr_esval = results.loc[results["test"] == i[1]][ - "difference" - ].iloc[0] - curr_esval_str = np.format_float_positional( - curr_esval, - precision=es_sf, - sign=True, - trim="k", - min_digits=es_sf, + # Contrast bars WIP + contrast_bars = plot_kwargs["contrast_bars"] + if contrast_bars: + contrast_bars_plotter( + results=results, + ax_to_plot=contrast_axes, + swarm_plot_ax=rawdata_axes, + ticks_to_plot=ticks_to_plot, + contrast_bars_kwargs=contrast_bars_kwargs, + color_col=color_col, + plot_palette_raw=plot_palette_raw, + show_mini_meta=show_mini_meta, + mini_meta_delta=effectsize_df.mini_meta_delta if show_mini_meta else None, + show_delta2=show_delta2, + delta_delta=effectsize_df.delta_delta if show_delta2 else None, + proportional=proportional, + is_paired=is_paired + ) + + # Summary bars WIP + summary_bars = plot_kwargs["summary_bars"] + if summary_bars is not None: + summary_bars_plotter( + summary_bars=summary_bars, + results=results, + ax_to_plot=contrast_axes, + float_contrast=float_contrast, + summary_bars_kwargs=summary_bars_kwargs, + ci_type=ci_type, + ticks_to_plot=ticks_to_plot, + color_col=color_col, + swarm_colors=swarm_colors, + proportional=proportional, + is_paired=is_paired + ) + # Delta text WIP + delta_text = plot_kwargs["delta_text"] + if delta_text: + delta_text_plotter( + results=results, + ax_to_plot=contrast_axes, + swarm_plot_ax=rawdata_axes, + ticks_to_plot=ticks_to_plot, + delta_text_kwargs=delta_text_kwargs, + color_col=color_col, + swarm_colors=swarm_colors, + is_paired=is_paired, + proportional=proportional, + float_contrast=float_contrast, + show_mini_meta=show_mini_meta, + mini_meta_delta=effectsize_df.mini_meta_delta if show_mini_meta else None, + show_delta2=show_delta2, + delta_delta=effectsize_df.delta_delta if show_delta2 else None ) - effsize_list.append(curr_esval_str) - else: - effsize_list.append("-") - - table_cellcols.append(effsize_list) - - # If Gardner-Altman plot, plot on raw data and not contrast axes - if float_contrast: - axes_ploton = rawdata_axes - else: - axes_ploton = contrast_axes - - # Account for extended x axis in case of show_delta2 or show_mini_meta - x_groups_for_width = len(groups_for_gridkey) - if show_delta2 or show_mini_meta: - x_groups_for_width += 2 - gridkey_width = len(groups_for_gridkey) / x_groups_for_width - - gridkey = axes_ploton.table( - cellText=table_cellcols, - rowLabels=gridkey_rows, - cellLoc="center", - bbox=[ - 0, - -len(gridkey_rows) * 0.1 - 0.05, - gridkey_width, - len(gridkey_rows) * 0.1, - ], - **{"alpha": 0.5} - ) - - # modifies row label cells - for cell in gridkey._cells: - if cell[1] == -1: - gridkey._cells[cell].visible_edges = "open" - gridkey._cells[cell].set_text_props(**{"ha": "right"}) - - # turns off both x axes - rawdata_axes.get_xaxis().set_visible(False) - contrast_axes.get_xaxis().set_visible(False) - - ####################################################### END GRIDKEY MAIN CODE WIP + ################################################### Swarm & Contrast & Summary Bars & Delta text WIP END # Make sure no stray ticks appear! rawdata_axes.xaxis.set_ticks_position("bottom") @@ -1602,5 +589,5 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs): plt.rcParams[parameter] = original_rcParams[parameter] # Return the figure. + fig.show() return fig - diff --git a/nbs/API/effsize_objects.ipynb b/nbs/API/effsize_objects.ipynb index fd8496c1..4d0bde93 100644 --- a/nbs/API/effsize_objects.ipynb +++ b/nbs/API/effsize_objects.ipynb @@ -1132,6 +1132,7 @@ " contrast_ylim=None,\n", " delta2_ylim=None,\n", " swarm_side=None,\n", + " empty_circle=False,\n", " custom_palette=None,\n", " swarm_desat=0.5,\n", " halfviolin_desat=1,\n", @@ -1155,10 +1156,6 @@ " fig_size=None,\n", " dpi=100,\n", " ax=None,\n", - " contrast_show_es=False,\n", - " es_sf=2,\n", - " es_fontsize=10,\n", - " contrast_show_deltas=True,\n", " gridkey_rows=None,\n", " gridkey_merge_pairs=False,\n", " gridkey_show_Ns=True,\n", @@ -1178,6 +1175,17 @@ " fontsize_contrastxlabel=12,\n", " fontsize_contrastylabel=12,\n", " fontsize_delta2label=12,\n", + " #### Contrast bars and delta text and delta dots WIP ####\n", + " contrast_bars=True,\n", + " swarm_bars=True,\n", + " contrast_bars_kwargs=None,\n", + " swarm_bars_kwargs=None,\n", + " summary_bars=None,\n", + " summary_bars_kwargs=None,\n", + " delta_text=True,\n", + " delta_text_kwargs=None,\n", + " delta_dot=True,\n", + " delta_dot_kwargs=None,\n", " ):\n", " \"\"\"\n", " Creates an estimation plot for the effect size of interest.\n", @@ -1225,6 +1233,12 @@ " https://seaborn.pydata.org/generated/seaborn.cubehelix_palette.html\n", " The named colors of matplotlib can be found here:\n", " https://matplotlib.org/examples/color/named_colors.html\n", + " swarm_side: string, default None\n", + " The side on which points are swarmed for swarmplots (\"center\", \"left\", or \"right\").\n", + " empty_circle: boolean, default False\n", + " Boolean value determining if empty circles will be used for plotting of\n", + " swarmplot for control groups. Color of each individual swarm is also now\n", + " dependent on the comparison group.\n", " swarm_desat : float, default 1\n", " Decreases the saturation of the colors in the swarmplot by the\n", " desired proportion. Uses `seaborn.desaturate()` to acheive this.\n", @@ -1320,7 +1334,39 @@ " Font size for the contrast axes ylabel.\n", " fontsize_delta2label : float, default 12\n", " Font size for the delta-delta axes ylabel.\n", - "\n", + " \n", + " \n", + " contrast_bars : boolean, default True\n", + " Whether or not to display the contrast bars.\n", + " swarm_bars : boolean, default True\n", + " Whether or not to display the swarm bars.\n", + " contrast_bars_kwargs : dict, default None\n", + " Pass relevant keyword arguments to the contrast bars. Pass any keyword arguments accepted by \n", + " matplotlib.patches.Rectangle here, as a string. If None, the following keywords are passed:\n", + " {\"color\": None, \"alpha\": 0.3}\n", + " swarm_bars_kwargs : dict, default None\n", + " Pass relevant keyword arguments to the swarm bars. Pass any keyword arguments accepted by \n", + " matplotlib.patches.Rectangle here, as a string. If None, the following keywords are passed:\n", + " {\"color\": None, \"alpha\": 0.3}\n", + "\n", + " summary_bars : list, default None\n", + " Pass a list of indices of the contrast objects to have summary bars displayed on the plot.\n", + " For example, [0,1] will show summary bars for the first two contrast objects.\n", + " summary_bars_kwargs: dict, default None\n", + " If None, the following keywords are passed: {\"color\": None, \"alpha\": 0.15}\n", + " delta_text : boolean, default True\n", + " Whether or not to display the text deltas.\n", + " delta_text_kwargs : dict, default None\n", + " Pass relevant keyword arguments to the delta text. Pass any keyword arguments accepted by\n", + " matplotlib.text.Text here, as a string. If None, the following keywords are passed:\n", + " {\"color\": None, \"alpha\": 1, \"fontsize\": 10, \"ha\": 'center', \"va\": 'center', \"rotation\": 0, \n", + " \"x_location\": 'right', \"x_coordinates\": None, \"y_coordinates\": None}\n", + " Use \"x_coordinates\" and \"y_coordinates\" if you would like to specify the text locations manually.\n", + " delta_dot : boolean, default True\n", + " Whether or not to display the delta dots on paired or repeated measure plots.\n", + " delta_dot_kwargs : dict, default None\n", + " Pass relevant keyword arguments. If None, the following keywords are passed:\n", + " {\"marker\": \"^\", \"alpha\": 0.5, \"zorder\": 2, \"size\": 3, \"side\": \"right\"}\n", "\n", " Returns\n", " -------\n", @@ -1341,7 +1387,7 @@ " if hasattr(self, \"results\") is False:\n", " self.__pre_calc()\n", "\n", - " if self.__delta2:\n", + " if self.__delta2 and not empty_circle:\n", " color_col = self.__x2\n", "\n", " # if self.__proportional:\n", @@ -1921,7 +1967,13 @@ "source": [] } ], - "metadata": {}, + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, "nbformat": 4, "nbformat_minor": 2 } diff --git a/nbs/API/misc_tools.ipynb b/nbs/API/misc_tools.ipynb index 0395a57c..dd9fbdd0 100644 --- a/nbs/API/misc_tools.ipynb +++ b/nbs/API/misc_tools.ipynb @@ -55,7 +55,12 @@ "source": [ "#| export\n", "import datetime as dt\n", - "from numpy import repeat" + "import numpy as np\n", + "from numpy import repeat\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib" ] }, { @@ -125,7 +130,1026 @@ " matching_vars = [k for k, v in globals().items() if v is obj]\n", " if len(matching_vars) > 0:\n", " return matching_vars[0]\n", - " return \"\"" + " return \"\"\n", + "\n", + "def get_params(effectsize_df, plot_kwargs):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " effectsize_df : object (Dataframe)\n", + " A `dabest` EffectSizeDataFrame object.\n", + " plot_kwargs : dict\n", + " Kwargs passed to the plot function.\n", + " \"\"\"\n", + " dabest_obj = effectsize_df.dabest_obj\n", + " plot_data = effectsize_df._plot_data\n", + " xvar = effectsize_df.xvar\n", + " yvar = effectsize_df.yvar\n", + " is_paired = effectsize_df.is_paired\n", + " delta2 = effectsize_df.delta2\n", + " mini_meta = effectsize_df.mini_meta\n", + " effect_size = effectsize_df.effect_size\n", + " proportional = effectsize_df.proportional\n", + " all_plot_groups = dabest_obj._all_plot_groups\n", + " idx = dabest_obj.idx\n", + "\n", + " if effect_size not in [\"mean_diff\", \"delta_g\"] or not delta2:\n", + " show_delta2 = False\n", + " else:\n", + " show_delta2 = plot_kwargs[\"show_delta2\"]\n", + "\n", + " if effect_size != \"mean_diff\" or not mini_meta:\n", + " show_mini_meta = False\n", + " else:\n", + " show_mini_meta = plot_kwargs[\"show_mini_meta\"]\n", + "\n", + " if show_delta2 and show_mini_meta: raise ValueError(\"`show_delta2` and `show_mini_meta` cannot be True at the same time.\")\n", + "\n", + " # Disable Gardner-Altman plotting if any of the idxs comprise of more than\n", + " # two groups or if it is a delta-delta plot.\n", + " float_contrast = plot_kwargs[\"float_contrast\"]\n", + " effect_size_type = effectsize_df.effect_size\n", + " if len(idx) > 1 or len(idx[0]) > 2:\n", + " float_contrast = False\n", + "\n", + " if effect_size_type in [\"cliffs_delta\"]:\n", + " float_contrast = False\n", + "\n", + " if show_delta2 or show_mini_meta:\n", + " float_contrast = False\n", + "\n", + " if not is_paired:\n", + " show_pairs = False\n", + " else:\n", + " show_pairs = plot_kwargs[\"show_pairs\"]\n", + "\n", + " # Group summaries\n", + " group_summaries = plot_kwargs[\"group_summaries\"]\n", + " if group_summaries is None:\n", + " group_summaries = \"mean_sd\"\n", + "\n", + " # Error bar color\n", + " err_color = plot_kwargs[\"err_color\"]\n", + " if err_color is None: \n", + " err_color = \"black\"\n", + " \n", + " return (dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, proportional, all_plot_groups, idx, \n", + " show_delta2, show_mini_meta, float_contrast, show_pairs, effect_size_type, group_summaries, err_color)\n", + "\n", + "def get_kwargs(plot_kwargs, ytick_color):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " plot_kwargs : dict\n", + " Kwargs passed to the plot function.\n", + " ytick_color : str\n", + " Color of the yticks.\n", + " \"\"\"\n", + " from .misc_tools import merge_two_dicts\n", + "\n", + " # Swarmplot kwargs\n", + " default_swarmplot_kwargs = {\"size\": plot_kwargs[\"raw_marker_size\"]}\n", + " if plot_kwargs[\"swarmplot_kwargs\"] is None:\n", + " swarmplot_kwargs = default_swarmplot_kwargs\n", + " else:\n", + " swarmplot_kwargs = merge_two_dicts(\n", + " default_swarmplot_kwargs, plot_kwargs[\"swarmplot_kwargs\"]\n", + " )\n", + "\n", + " # Barplot kwargs\n", + " default_barplot_kwargs = {\"estimator\": np.mean, \"errorbar\": plot_kwargs[\"ci\"]}\n", + " if plot_kwargs[\"barplot_kwargs\"] is None:\n", + " barplot_kwargs = default_barplot_kwargs\n", + " else:\n", + " barplot_kwargs = merge_two_dicts(\n", + " default_barplot_kwargs, plot_kwargs[\"barplot_kwargs\"]\n", + " )\n", + "\n", + " # Sankey Diagram kwargs\n", + " default_sankey_kwargs = {\n", + " \"width\": 0.4,\n", + " \"align\": \"center\",\n", + " \"sankey\": True,\n", + " \"flow\": True,\n", + " \"alpha\": 0.4,\n", + " \"rightColor\": False,\n", + " \"bar_width\": 0.2,\n", + " }\n", + " if plot_kwargs[\"sankey_kwargs\"] is None:\n", + " sankey_kwargs = default_sankey_kwargs\n", + " else:\n", + " sankey_kwargs = merge_two_dicts(\n", + " default_sankey_kwargs, plot_kwargs[\"sankey_kwargs\"]\n", + " )\n", + "\n", + " # Violinplot kwargs.\n", + " default_violinplot_kwargs = {\n", + " \"widths\": 0.5,\n", + " \"vert\": True,\n", + " \"showextrema\": False,\n", + " \"showmedians\": False,\n", + " }\n", + " if plot_kwargs[\"violinplot_kwargs\"] is None:\n", + " violinplot_kwargs = default_violinplot_kwargs\n", + " else:\n", + " violinplot_kwargs = merge_two_dicts(\n", + " default_violinplot_kwargs, plot_kwargs[\"violinplot_kwargs\"]\n", + " )\n", + "\n", + " # Slopegraph kwargs.\n", + " default_slopegraph_kwargs = {\"linewidth\": 1, \"alpha\": 0.5}\n", + " if plot_kwargs[\"slopegraph_kwargs\"] is None:\n", + " slopegraph_kwargs = default_slopegraph_kwargs\n", + " else:\n", + " slopegraph_kwargs = merge_two_dicts(\n", + " default_slopegraph_kwargs, plot_kwargs[\"slopegraph_kwargs\"]\n", + " )\n", + "\n", + " # Zero reference-line kwargs.\n", + " default_reflines_kwargs = {\n", + " \"linestyle\": \"solid\",\n", + " \"linewidth\": 0.75,\n", + " \"zorder\": 2,\n", + " \"color\": ytick_color,\n", + " }\n", + " if plot_kwargs[\"reflines_kwargs\"] is None:\n", + " reflines_kwargs = default_reflines_kwargs\n", + " else:\n", + " reflines_kwargs = merge_two_dicts(\n", + " default_reflines_kwargs, plot_kwargs[\"reflines_kwargs\"]\n", + " )\n", + "\n", + " # Legend kwargs.\n", + " default_legend_kwargs = {\"loc\": \"upper left\", \"frameon\": False}\n", + " if plot_kwargs[\"legend_kwargs\"] is None:\n", + " legend_kwargs = default_legend_kwargs\n", + " else:\n", + " legend_kwargs = merge_two_dicts(\n", + " default_legend_kwargs, plot_kwargs[\"legend_kwargs\"]\n", + " )\n", + "\n", + " # Group summaries kwargs.\n", + " gs_default = {\"mean_sd\", \"median_quartiles\", None}\n", + " if plot_kwargs[\"group_summaries\"] not in gs_default:\n", + " raise ValueError(\n", + " \"group_summaries must be one of\" \" these: {}.\".format(gs_default)\n", + " )\n", + "\n", + " default_group_summary_kwargs = {\"zorder\": 3, \"lw\": 2, \"alpha\": 1}\n", + " if plot_kwargs[\"group_summary_kwargs\"] is None:\n", + " group_summary_kwargs = default_group_summary_kwargs\n", + " else:\n", + " group_summary_kwargs = merge_two_dicts(\n", + " default_group_summary_kwargs, plot_kwargs[\"group_summary_kwargs\"]\n", + " )\n", + "\n", + " # Redraw axes kwargs.\n", + " redraw_axes_kwargs = {\n", + " \"colors\": ytick_color,\n", + " \"facecolors\": ytick_color,\n", + " \"lw\": 1,\n", + " \"zorder\": 10,\n", + " \"clip_on\": False,\n", + " }\n", + " \n", + " # Delta dots kwargs.\n", + " default_delta_dot_kwargs = {\"marker\": \"^\", \"alpha\": 0.5, \"zorder\": 2, \"size\": 3, \"side\": \"right\"}\n", + " if plot_kwargs[\"delta_dot_kwargs\"] is None:\n", + " delta_dot_kwargs = default_delta_dot_kwargs\n", + " else:\n", + " delta_dot_kwargs = merge_two_dicts(default_delta_dot_kwargs, plot_kwargs[\"delta_dot_kwargs\"])\n", + "\n", + " # Delta text kwargs.\n", + " default_delta_text_kwargs = {\"color\": None, \"alpha\": 1, \"fontsize\": 10, \"ha\": 'center', \"va\": 'center', \"rotation\": 0, \"x_location\": 'right', \"x_coordinates\": None, \"y_coordinates\": None}\n", + " if plot_kwargs[\"delta_text_kwargs\"] is None:\n", + " delta_text_kwargs = default_delta_text_kwargs\n", + " else:\n", + " delta_text_kwargs = merge_two_dicts(default_delta_text_kwargs, plot_kwargs[\"delta_text_kwargs\"])\n", + "\n", + " # Summary bars kwargs.\n", + " default_summary_bars_kwargs = {\"color\": None, \"alpha\": 0.15}\n", + " if plot_kwargs[\"summary_bars_kwargs\"] is None:\n", + " summary_bars_kwargs = default_summary_bars_kwargs\n", + " else:\n", + " summary_bars_kwargs = merge_two_dicts(default_summary_bars_kwargs, plot_kwargs[\"summary_bars_kwargs\"])\n", + "\n", + " # Swarm bars kwargs.\n", + " default_swarm_bars_kwargs = {\"color\": None, \"alpha\": 0.3}\n", + " if plot_kwargs[\"swarm_bars_kwargs\"] is None:\n", + " swarm_bars_kwargs = default_swarm_bars_kwargs\n", + " else:\n", + " swarm_bars_kwargs = merge_two_dicts(default_swarm_bars_kwargs, plot_kwargs[\"swarm_bars_kwargs\"])\n", + "\n", + " # Contrast bars kwargs.\n", + " default_contrast_bars_kwargs = {\"color\": None, \"alpha\": 0.3}\n", + " if plot_kwargs[\"contrast_bars_kwargs\"] is None:\n", + " contrast_bars_kwargs = default_contrast_bars_kwargs\n", + " else:\n", + " contrast_bars_kwargs = merge_two_dicts(default_contrast_bars_kwargs, plot_kwargs[\"contrast_bars_kwargs\"])\n", + "\n", + " return (swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs, slopegraph_kwargs, \n", + " reflines_kwargs, legend_kwargs, group_summary_kwargs, redraw_axes_kwargs, delta_dot_kwargs,\n", + " delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs)\n", + "\n", + "\n", + "def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):\n", + "\n", + " # Create color palette that will be shared across subplots.\n", + " color_col = plot_kwargs[\"color_col\"]\n", + " if color_col is None:\n", + " color_groups = pd.unique(plot_data[xvar])\n", + " bootstraps_color_by_group = True\n", + " else:\n", + " if color_col not in plot_data.columns:\n", + " raise KeyError(\"``{}`` is not a column in the data.\".format(color_col))\n", + " color_groups = pd.unique(plot_data[color_col])\n", + " bootstraps_color_by_group = False\n", + " if show_pairs:\n", + " bootstraps_color_by_group = False\n", + " # Handle the color palette.\n", + " filled = True\n", + " empty_circle = plot_kwargs[\"empty_circle\"]\n", + " color_by_subgroups = (\n", + " True if empty_circle else False\n", + " ) # boolean flag to determine if colour is being grouped by subgroup or the default\n", + " if empty_circle:\n", + " # Handling color_by_subgroups\n", + " # For now, color_by_subgroups can only be True for multi-2-group and 2-group comparison\n", + " if isinstance(idx[0], str):\n", + " if len(idx) > 2:\n", + " color_by_subgroups = False\n", + " else:\n", + " for group_i in idx:\n", + " if len(group_i) > 2:\n", + " color_by_subgroups = False\n", + "\n", + " # filled is now a list, which determines the which group in idx has their dots filled for the swarmplot\n", + " filled = []\n", + " for i in range(len(idx)):\n", + " filled.append(False)\n", + " filled.extend([True] * (len(idx[i]) - 1))\n", + "\n", + " names = color_groups if not color_by_subgroups else idx\n", + " n_groups = len(color_groups)\n", + " custom_pal = plot_kwargs[\"custom_palette\"]\n", + " swarm_desat = plot_kwargs[\"swarm_desat\"]\n", + " bar_desat = plot_kwargs[\"bar_desat\"]\n", + " contrast_desat = plot_kwargs[\"halfviolin_desat\"]\n", + "\n", + " if custom_pal is None:\n", + " unsat_colors = sns.color_palette(n_colors=n_groups)\n", + " if empty_circle and not color_by_subgroups:\n", + " unsat_colors = [sns.color_palette(\"gray\")[3]] + unsat_colors\n", + " else:\n", + " if isinstance(custom_pal, dict):\n", + " groups_in_palette = {\n", + " k: v for k, v in custom_pal.items() if k in color_groups\n", + " }\n", + "\n", + " names = groups_in_palette.keys()\n", + " unsat_colors = groups_in_palette.values()\n", + "\n", + " elif isinstance(custom_pal, list):\n", + " unsat_colors = custom_pal[0:n_groups]\n", + "\n", + " elif isinstance(custom_pal, str):\n", + " # check it is in the list of matplotlib palettes.\n", + " if custom_pal in plt.colormaps():\n", + " unsat_colors = sns.color_palette(custom_pal, n_groups)\n", + " else:\n", + " err1 = \"The specified `custom_palette` {}\".format(custom_pal)\n", + " err2 = \" is not a matplotlib palette. Please check.\"\n", + " raise ValueError(err1 + err2)\n", + "\n", + " if custom_pal is None and color_col is None:\n", + " swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n", + " contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n", + " bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]\n", + " if color_by_subgroups:\n", + " plot_palette_raw = dict()\n", + " plot_palette_contrast = dict()\n", + " # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots\n", + " plot_palette_bar = None\n", + " for i in range(len(idx)):\n", + " for names_i in idx[i]:\n", + " plot_palette_raw[names_i] = swarm_colors[i]\n", + " plot_palette_contrast[names_i] = contrast_colors[i]\n", + " else:\n", + " plot_palette_raw = dict(zip(names.categories, swarm_colors))\n", + " plot_palette_contrast = dict(zip(names.categories, contrast_colors))\n", + " plot_palette_bar = dict(zip(names.categories, bar_color))\n", + "\n", + " # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors\n", + " # default color palette will be set to \"hls\"\n", + " plot_palette_sankey = None\n", + "\n", + " else:\n", + " swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n", + " contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n", + " bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]\n", + " if color_by_subgroups:\n", + " plot_palette_raw = dict()\n", + " plot_palette_contrast = dict()\n", + " # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots\n", + " plot_palette_bar = None\n", + " for i in range(len(idx)):\n", + " for names_i in idx[i]:\n", + " plot_palette_raw[names_i] = swarm_colors[i]\n", + " plot_palette_contrast[names_i] = contrast_colors[i]\n", + " else:\n", + " plot_palette_raw = dict(zip(names, swarm_colors))\n", + " plot_palette_contrast = dict(zip(names, contrast_colors))\n", + " plot_palette_bar = dict(zip(names, bar_color))\n", + "\n", + " plot_palette_sankey = custom_pal\n", + "\n", + " return (color_col, bootstraps_color_by_group, n_groups, filled, swarm_colors, plot_palette_raw, \n", + " bar_color, plot_palette_bar, plot_palette_contrast, plot_palette_sankey)\n", + "\n", + "def initialize_fig(plot_kwargs, dabest_obj, show_delta2, show_mini_meta, is_paired, show_pairs, proportional,\n", + " float_contrast):\n", + " # Params\n", + " fig_size = plot_kwargs[\"fig_size\"]\n", + " face_color = plot_kwargs[\"face_color\"]\n", + " if plot_kwargs[\"face_color\"] is None:\n", + " face_color = \"white\"\n", + "\n", + " if fig_size is None:\n", + " all_groups_count = np.sum([len(i) for i in dabest_obj.idx])\n", + " # Increase the width for delta-delta graph\n", + " if show_delta2 or show_mini_meta:\n", + " all_groups_count += 2\n", + " if is_paired and show_pairs and proportional is False:\n", + " frac = 0.8\n", + " else:\n", + " frac = 1\n", + " if float_contrast:\n", + " height_inches = 4\n", + " each_group_width_inches = 2.5 * frac\n", + " else:\n", + " height_inches = 6\n", + " each_group_width_inches = 1.5 * frac\n", + "\n", + " width_inches = each_group_width_inches * all_groups_count\n", + " fig_size = (width_inches, height_inches)\n", + "\n", + " init_fig_kwargs = dict(figsize=fig_size, dpi=plot_kwargs[\"dpi\"], tight_layout=True)\n", + " width_ratios_ga = [2.5, 1]\n", + "\n", + " h_space_cummings = 0.3 if plot_kwargs[\"gridkey_rows\"] == None else 0.1 ##### GRIDKEY WIP addition\n", + "\n", + " if plot_kwargs[\"ax\"] is not None:\n", + " # New in v0.2.6.\n", + " # Use inset axes to create the estimation plot inside a single axes.\n", + " # Author: Adam L Nekimken. (PR #73)\n", + " rawdata_axes = plot_kwargs[\"ax\"]\n", + " ax_position = rawdata_axes.get_position() # [[x0, y0], [x1, y1]]\n", + "\n", + " fig = rawdata_axes.get_figure()\n", + " fig.patch.set_facecolor(face_color)\n", + "\n", + " if float_contrast:\n", + " axins = rawdata_axes.inset_axes(\n", + " [1, 0, width_ratios_ga[1] / width_ratios_ga[0], 1]\n", + " )\n", + " rawdata_axes.set_position( # [l, b, w, h]\n", + " [\n", + " ax_position.x0,\n", + " ax_position.y0,\n", + " (ax_position.x1 - ax_position.x0)\n", + " * (width_ratios_ga[0] / sum(width_ratios_ga)),\n", + " (ax_position.y1 - ax_position.y0),\n", + " ]\n", + " )\n", + "\n", + " contrast_axes = axins\n", + " else:\n", + " axins = rawdata_axes.inset_axes([0, -1 - h_space_cummings, 1, 1])\n", + " plot_height = (ax_position.y1 - ax_position.y0) / (2 + h_space_cummings)\n", + " rawdata_axes.set_position(\n", + " [\n", + " ax_position.x0,\n", + " ax_position.y0 + (1 + h_space_cummings) * plot_height,\n", + " (ax_position.x1 - ax_position.x0),\n", + " plot_height,\n", + " ]\n", + " )\n", + "\n", + " contrast_axes = axins\n", + " rawdata_axes.contrast_axes = axins\n", + "\n", + " else:\n", + " # Here, we hardcode some figure parameters.\n", + " if float_contrast:\n", + " fig, axx = plt.subplots(\n", + " ncols=2,\n", + " gridspec_kw={\"width_ratios\": width_ratios_ga, \"wspace\": 0},\n", + " **init_fig_kwargs\n", + " )\n", + " fig.patch.set_facecolor(face_color)\n", + "\n", + " else:\n", + " fig, axx = plt.subplots(\n", + " nrows=2, gridspec_kw={\"hspace\": h_space_cummings}, **init_fig_kwargs\n", + " )\n", + " fig.patch.set_facecolor(face_color)\n", + "\n", + " # Title\n", + " title = plot_kwargs[\"title\"]\n", + " fontsize_title = plot_kwargs[\"fontsize_title\"]\n", + " if title is not None:\n", + " fig.suptitle(title, fontsize=fontsize_title)\n", + " rawdata_axes = axx[0]\n", + " contrast_axes = axx[1]\n", + " rawdata_axes.set_frame_on(False)\n", + " contrast_axes.set_frame_on(False)\n", + "\n", + " swarm_ylim = plot_kwargs[\"swarm_ylim\"]\n", + " if swarm_ylim is not None:\n", + " rawdata_axes.set_ylim(swarm_ylim)\n", + "\n", + " return fig, rawdata_axes, contrast_axes, swarm_ylim\n", + "\n", + "def get_plot_groups(is_paired, idx, proportional, all_plot_groups):\n", + "\n", + " if is_paired == \"baseline\":\n", + " idx_pairs = [\n", + " (control, test)\n", + " for i in idx\n", + " for control, test in zip([i[0]] * (len(i) - 1), i[1:])\n", + " ]\n", + " temp_idx = idx if not proportional else idx_pairs\n", + " else:\n", + " idx_pairs = [\n", + " (control, test) for i in idx for control, test in zip(i[:-1], i[1:])\n", + " ]\n", + " temp_idx = idx if not proportional else idx_pairs\n", + "\n", + " # Determine temp_all_plot_groups based on proportional condition\n", + " plot_groups = [item for i in temp_idx for item in i]\n", + " temp_all_plot_groups = all_plot_groups if not proportional else plot_groups\n", + " \n", + " return temp_idx, temp_all_plot_groups\n", + "\n", + "\n", + "def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):\n", + " # Add the counts to the rawdata axes xticks.\n", + " counts = plot_data.groupby(xvar).count()[yvar]\n", + " ticks_with_counts = []\n", + " ticks_loc = rawdata_axes.get_xticks()\n", + " rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))\n", + " def lookup_value(text, counts):\n", + " try:\n", + " return str(counts.loc[text])\n", + " except KeyError:\n", + " try:\n", + " numeric_key = pd.to_numeric(text, errors='coerce')\n", + " if pd.notnull(numeric_key):\n", + " return str(counts.loc[numeric_key])\n", + " else:\n", + " raise ValueError\n", + " except (ValueError, KeyError):\n", + " print(f\"Key '{text}' not found in counts.\")\n", + " return \"N/A\"\n", + " for xticklab in rawdata_axes.xaxis.get_ticklabels():\n", + " t = xticklab.get_text()\n", + " # Extract the text after the last newline, if present\n", + " if t.rfind(\"\\n\") != -1:\n", + " te = t[t.rfind(\"\\n\") + len(\"\\n\"):]\n", + " value = lookup_value(te, counts)\n", + " te = t\n", + " else:\n", + " te = t\n", + " value = lookup_value(te, counts)\n", + "\n", + " # Append the modified tick label with the count to the list\n", + " ticks_with_counts.append(f\"{te}\\nN = {value}\")\n", + "\n", + "\n", + " if plot_kwargs[\"fontsize_rawxlabel\"] is not None:\n", + " fontsize_rawxlabel = plot_kwargs[\"fontsize_rawxlabel\"]\n", + " rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel)\n", + "\n", + "\n", + "def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_groups, idx, sankey_control_group):\n", + "\n", + " # Take note of where the `control` groups are.\n", + " ticks_to_skip_contrast = None\n", + " ticks_to_start_twocol_sankey = None\n", + " if is_paired == \"baseline\" and show_pairs:\n", + " if two_col_sankey:\n", + " ticks_to_skip = []\n", + " ticks_to_plot = np.arange(0, len(plot_groups) / 2).tolist()\n", + " ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()\n", + " ticks_to_start_twocol_sankey.pop()\n", + " ticks_to_start_twocol_sankey.insert(0, 0)\n", + " else:\n", + " # ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist()\n", + " # ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist()\n", + " ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()\n", + " ticks_to_skip.insert(0, 0)\n", + " # Then obtain the ticks where we have to plot the effect sizes.\n", + " ticks_to_plot = [\n", + " t for t in range(0, len(plot_groups)) if t not in ticks_to_skip\n", + " ]\n", + " ticks_to_skip_contrast = np.cumsum([(len(t)) for t in idx])[:-1].tolist()\n", + " ticks_to_skip_contrast.insert(0, 0)\n", + " else:\n", + " if two_col_sankey:\n", + " ticks_to_skip = [len(sankey_control_group)]\n", + " # Then obtain the ticks where we have to plot the effect sizes.\n", + " ticks_to_plot = [\n", + " t for t in range(0, len(plot_groups)) if t not in ticks_to_skip\n", + " ]\n", + " ticks_to_skip = []\n", + " ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()\n", + " ticks_to_start_twocol_sankey.pop()\n", + " ticks_to_start_twocol_sankey.insert(0, 0)\n", + " else:\n", + " ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()\n", + " ticks_to_skip.insert(0, 0)\n", + " # Then obtain the ticks where we have to plot the effect sizes.\n", + " ticks_to_plot = [\n", + " t for t in range(0, len(plot_groups)) if t not in ticks_to_skip\n", + " ]\n", + " \n", + " return ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast, ticks_to_start_twocol_sankey\n", + "\n", + "def set_xaxis_ticks_and_lims(show_delta2, show_mini_meta, rawdata_axes, contrast_axes, show_pairs, float_contrast,\n", + " ticks_to_skip, contrast_xtick_labels, plot_kwargs):\n", + "\n", + " if show_delta2 is False and show_mini_meta is False:\n", + " contrast_axes.set_xticks(rawdata_axes.get_xticks())\n", + " else:\n", + " temp = rawdata_axes.get_xticks()\n", + " temp = np.append(temp, [max(temp) + 1, max(temp) + 2])\n", + " contrast_axes.set_xticks(temp)\n", + "\n", + " if show_pairs:\n", + " max_x = contrast_axes.get_xlim()[1]\n", + " rawdata_axes.set_xlim(-0.375, max_x)\n", + "\n", + " if float_contrast:\n", + " contrast_axes.set_xlim(0.5, 1.5)\n", + " elif show_delta2 or show_mini_meta:\n", + " # Increase the xlim of raw data by 2\n", + " temp = rawdata_axes.get_xlim()\n", + " if show_pairs:\n", + " rawdata_axes.set_xlim(temp[0], temp[1] + 0.25)\n", + " else:\n", + " rawdata_axes.set_xlim(temp[0], temp[1] + 2)\n", + " contrast_axes.set_xlim(rawdata_axes.get_xlim())\n", + " else:\n", + " contrast_axes.set_xlim(rawdata_axes.get_xlim())\n", + "\n", + " # Properly label the contrast ticks.\n", + " for t in ticks_to_skip:\n", + " contrast_xtick_labels.insert(t, \"\")\n", + "\n", + " if plot_kwargs[\"fontsize_contrastxlabel\"] is not None:\n", + " fontsize_contrastxlabel = plot_kwargs[\"fontsize_contrastxlabel\"]\n", + "\n", + " contrast_axes.set_xticklabels(\n", + " contrast_xtick_labels, fontsize=fontsize_contrastxlabel\n", + " )\n", + "\n", + "\n", + "def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, float_contrast, show_pairs, legend_kwargs):\n", + "\n", + " legend_labels_unique = np.unique(legend_labels)\n", + " unique_idx = np.unique(legend_labels, return_index=True)[1]\n", + " legend_handles_unique = (\n", + " pd.Series(legend_handles, dtype=\"object\").loc[unique_idx]\n", + " ).tolist()\n", + "\n", + " if len(legend_handles_unique) > 0:\n", + " if float_contrast:\n", + " axes_with_legend = contrast_axes\n", + " if show_pairs:\n", + " bta = (2.00, 1.02)\n", + " else:\n", + " bta = (1.5, 1.02)\n", + " else:\n", + " axes_with_legend = rawdata_axes\n", + " if show_pairs:\n", + " bta = (1.02, 1.0)\n", + " else:\n", + " bta = (1.0, 1.0)\n", + " leg = axes_with_legend.legend(\n", + " legend_handles_unique,\n", + " legend_labels_unique,\n", + " bbox_to_anchor=bta,\n", + " **legend_kwargs\n", + " )\n", + " if show_pairs:\n", + " for line in leg.get_lines():\n", + " line.set_linewidth(3.0)\n", + " \n", + "def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar, yvar, current_control, current_group,\n", + " rawdata_axes, contrast_axes, results, current_effsize, is_paired, one_sankey,\n", + " reflines_kwargs, redraw_axes_kwargs, swarm_ylim, og_xlim_raw, og_ylim_raw):\n", + " from ._stats_tools.effsize import (\n", + " _compute_standardizers,\n", + " _compute_hedges_correction_factor,\n", + " )\n", + " # Normalize ylims and despine the floating contrast axes.\n", + " # Check that the effect size is within the swarm ylims.\n", + " if effect_size_type in [\"mean_diff\", \"cohens_d\", \"hedges_g\", \"cohens_h\"]:\n", + " control_group_summary = (\n", + " plot_data.groupby(xvar)\n", + " .mean(numeric_only=True)\n", + " .loc[current_control, yvar]\n", + " )\n", + " test_group_summary = (\n", + " plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar]\n", + " )\n", + " elif effect_size_type == \"median_diff\":\n", + " control_group_summary = (\n", + " plot_data.groupby(xvar).median().loc[current_control, yvar]\n", + " )\n", + " test_group_summary = (\n", + " plot_data.groupby(xvar).median().loc[current_group, yvar]\n", + " )\n", + "\n", + " if swarm_ylim is None:\n", + " swarm_ylim = rawdata_axes.get_ylim()\n", + "\n", + " _, contrast_xlim_max = contrast_axes.get_xlim()\n", + "\n", + " difference = float(results.difference[0])\n", + "\n", + " if effect_size_type in [\"mean_diff\", \"median_diff\"]:\n", + " # Align 0 of contrast_axes to reference group mean of rawdata_axes.\n", + " # If the effect size is positive, shift the contrast axis up.\n", + " rawdata_ylims = np.array(rawdata_axes.get_ylim())\n", + " if current_effsize > 0:\n", + " rightmin, rightmax = rawdata_ylims - current_effsize\n", + " # If the effect size is negative, shift the contrast axis down.\n", + " elif current_effsize < 0:\n", + " rightmin, rightmax = rawdata_ylims + current_effsize\n", + " else:\n", + " rightmin, rightmax = rawdata_ylims\n", + "\n", + " contrast_axes.set_ylim(rightmin, rightmax)\n", + "\n", + " og_ylim_contrast = rawdata_axes.get_ylim() - np.array(control_group_summary)\n", + "\n", + " contrast_axes.set_ylim(og_ylim_contrast)\n", + " contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max)\n", + "\n", + " elif effect_size_type in [\"cohens_d\", \"hedges_g\", \"cohens_h\"]:\n", + " if is_paired:\n", + " which_std = 1\n", + " else:\n", + " which_std = 0\n", + " temp_control = plot_data[plot_data[xvar] == current_control][yvar]\n", + " temp_test = plot_data[plot_data[xvar] == current_group][yvar]\n", + "\n", + " stds = _compute_standardizers(temp_control, temp_test)\n", + " if is_paired:\n", + " pooled_sd = stds[1]\n", + " else:\n", + " pooled_sd = stds[0]\n", + "\n", + " if effect_size_type == \"hedges_g\":\n", + " gby_count = plot_data.groupby(xvar).count()\n", + " len_control = gby_count.loc[current_control, yvar]\n", + " len_test = gby_count.loc[current_group, yvar]\n", + "\n", + " hg_correction_factor = _compute_hedges_correction_factor(\n", + " len_control, len_test\n", + " )\n", + "\n", + " ylim_scale_factor = pooled_sd / hg_correction_factor\n", + "\n", + " elif effect_size_type == \"cohens_h\":\n", + " ylim_scale_factor = (\n", + " np.mean(temp_test) - np.mean(temp_control)\n", + " ) / difference\n", + "\n", + " else:\n", + " ylim_scale_factor = pooled_sd\n", + "\n", + " scaled_ylim = (\n", + " (rawdata_axes.get_ylim() - control_group_summary) / ylim_scale_factor\n", + " ).tolist()\n", + "\n", + " contrast_axes.set_ylim(scaled_ylim)\n", + " og_ylim_contrast = scaled_ylim\n", + "\n", + " contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max)\n", + "\n", + " if one_sankey is None:\n", + " # Draw summary lines for control and test groups..\n", + " for jj, axx in enumerate([rawdata_axes, contrast_axes]):\n", + " # Draw effect size line.\n", + " if jj == 0:\n", + " ref = control_group_summary\n", + " diff = test_group_summary\n", + " effsize_line_start = 1\n", + "\n", + " elif jj == 1:\n", + " ref = 0\n", + " diff = ref + difference\n", + " effsize_line_start = contrast_xlim_max - 1.1\n", + "\n", + " xlimlow, xlimhigh = axx.get_xlim()\n", + "\n", + " # Draw reference line.\n", + " axx.hlines(\n", + " ref, # y-coordinates\n", + " 0,\n", + " xlimhigh, # x-coordinates, start and end.\n", + " **reflines_kwargs\n", + " )\n", + "\n", + " # Draw effect size line.\n", + " axx.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs)\n", + " else:\n", + " ref = 0\n", + " diff = ref + difference\n", + " effsize_line_start = contrast_xlim_max - 0.9\n", + " xlimlow, xlimhigh = contrast_axes.get_xlim()\n", + " # Draw reference line.\n", + " contrast_axes.hlines(\n", + " ref, # y-coordinates\n", + " effsize_line_start,\n", + " xlimhigh, # x-coordinates, start and end.\n", + " **reflines_kwargs\n", + " )\n", + "\n", + " # Draw effect size line.\n", + " contrast_axes.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs)\n", + " rawdata_axes.set_xlim(og_xlim_raw) # to align the axis\n", + " # Despine appropriately.\n", + " sns.despine(ax=rawdata_axes, bottom=True)\n", + " sns.despine(ax=contrast_axes, left=True, right=False)\n", + "\n", + " # Insert break between the rawdata axes and the contrast axes\n", + " # by re-drawing the x-spine.\n", + " rawdata_axes.hlines(\n", + " og_ylim_raw[0], # yindex\n", + " rawdata_axes.get_xlim()[0],\n", + " 1.3, # xmin, xmax\n", + " **redraw_axes_kwargs\n", + " )\n", + " rawdata_axes.set_ylim(og_ylim_raw)\n", + "\n", + " contrast_axes.hlines(\n", + " contrast_axes.get_ylim()[0],\n", + " contrast_xlim_max - 0.8,\n", + " contrast_xlim_max,\n", + " **redraw_axes_kwargs\n", + " )\n", + "\n", + "\n", + "def Cumming_Plot_Aesthetic_Adjustments(plot_kwargs, show_delta2, effect_size_type, contrast_axes, reflines_kwargs, \n", + " is_paired, show_pairs, two_col_sankey, idx, ticks_to_start_twocol_sankey,\n", + " proportional, ticks_to_skip, temp_idx, rawdata_axes, redraw_axes_kwargs,\n", + " ticks_to_skip_contrast):\n", + " # Set custom contrast_ylim, if it was specified.\n", + " if plot_kwargs[\"contrast_ylim\"] is not None or (\n", + " plot_kwargs[\"delta2_ylim\"] is not None and show_delta2\n", + " ):\n", + " if plot_kwargs[\"contrast_ylim\"] is not None:\n", + " custom_contrast_ylim = plot_kwargs[\"contrast_ylim\"]\n", + " if plot_kwargs[\"delta2_ylim\"] is not None and show_delta2:\n", + " custom_delta2_ylim = plot_kwargs[\"delta2_ylim\"]\n", + " if custom_contrast_ylim != custom_delta2_ylim:\n", + " err1 = \"Please check if `contrast_ylim` and `delta2_ylim` are assigned\"\n", + " err2 = \"with same values.\"\n", + " raise ValueError(err1 + err2)\n", + " else:\n", + " custom_delta2_ylim = plot_kwargs[\"delta2_ylim\"]\n", + " custom_contrast_ylim = custom_delta2_ylim\n", + "\n", + " if len(custom_contrast_ylim) != 2:\n", + " err1 = \"Please check `contrast_ylim` consists of \"\n", + " err2 = \"exactly two numbers.\"\n", + " raise ValueError(err1 + err2)\n", + "\n", + " if effect_size_type == \"cliffs_delta\":\n", + " # Ensure the ylims for a cliffs_delta plot never exceed [-1, 1].\n", + " l = plot_kwargs[\"contrast_ylim\"][0]\n", + " h = plot_kwargs[\"contrast_ylim\"][1]\n", + " low = -1 if l < -1 else l\n", + " high = 1 if h > 1 else h\n", + " contrast_axes.set_ylim(low, high)\n", + " else:\n", + " contrast_axes.set_ylim(custom_contrast_ylim)\n", + "\n", + "\n", + " # If 0 lies within the ylim of the contrast axes,\n", + " # draw a zero reference line.\n", + " contrast_axes_ylim = contrast_axes.get_ylim()\n", + " if contrast_axes_ylim[0] < contrast_axes_ylim[1]:\n", + " contrast_ylim_low, contrast_ylim_high = contrast_axes_ylim\n", + " else:\n", + " contrast_ylim_high, contrast_ylim_low = contrast_axes_ylim\n", + " if contrast_ylim_low < 0 < contrast_ylim_high:\n", + " contrast_axes.axhline(y=0, **reflines_kwargs)\n", + "\n", + " if is_paired == \"baseline\" and show_pairs:\n", + " if two_col_sankey:\n", + " rightend_ticks_raw = np.array([len(i) - 2 for i in idx]) + np.array(\n", + " ticks_to_start_twocol_sankey\n", + " )\n", + " elif proportional and is_paired is not None:\n", + " rightend_ticks_raw = np.array([len(i) - 1 for i in idx]) + np.array(\n", + " ticks_to_skip\n", + " )\n", + " else:\n", + " rightend_ticks_raw = np.array(\n", + " [len(i) - 1 for i in temp_idx]\n", + " ) + np.array(ticks_to_skip)\n", + " for ax in [rawdata_axes]:\n", + " sns.despine(ax=ax, bottom=True)\n", + "\n", + " ylim = ax.get_ylim()\n", + " xlim = ax.get_xlim()\n", + " redraw_axes_kwargs[\"y\"] = ylim[0]\n", + "\n", + " if two_col_sankey:\n", + " for k, start_tick in enumerate(ticks_to_start_twocol_sankey):\n", + " end_tick = rightend_ticks_raw[k]\n", + " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", + " else:\n", + " for k, start_tick in enumerate(ticks_to_skip):\n", + " end_tick = rightend_ticks_raw[k]\n", + " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", + " ax.set_ylim(ylim)\n", + " del redraw_axes_kwargs[\"y\"]\n", + "\n", + " if not proportional:\n", + " temp_length = [(len(i) - 1) for i in idx]\n", + " else:\n", + " temp_length = [(len(i) - 1) * 2 - 1 for i in idx]\n", + " if two_col_sankey:\n", + " rightend_ticks_contrast = np.array(\n", + " [len(i) - 2 for i in idx]\n", + " ) + np.array(ticks_to_start_twocol_sankey)\n", + " elif proportional and is_paired is not None:\n", + " rightend_ticks_contrast = np.array(\n", + " [len(i) - 1 for i in idx]\n", + " ) + np.array(ticks_to_skip)\n", + " else:\n", + " rightend_ticks_contrast = np.array(temp_length) + np.array(\n", + " ticks_to_skip_contrast\n", + " )\n", + " for ax in [contrast_axes]:\n", + " sns.despine(ax=ax, bottom=True)\n", + "\n", + " ylim = ax.get_ylim()\n", + " xlim = ax.get_xlim()\n", + " redraw_axes_kwargs[\"y\"] = ylim[0]\n", + "\n", + " if two_col_sankey:\n", + " for k, start_tick in enumerate(ticks_to_start_twocol_sankey):\n", + " end_tick = rightend_ticks_contrast[k]\n", + " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", + " else:\n", + " for k, start_tick in enumerate(ticks_to_skip_contrast):\n", + " end_tick = rightend_ticks_contrast[k]\n", + " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", + "\n", + " ax.set_ylim(ylim)\n", + " del redraw_axes_kwargs[\"y\"]\n", + " else:\n", + " # Compute the end of each x-axes line.\n", + " if two_col_sankey:\n", + " rightend_ticks = np.array([len(i) - 2 for i in idx]) + np.array(\n", + " ticks_to_start_twocol_sankey\n", + " )\n", + " else:\n", + " rightend_ticks = np.array([len(i) - 1 for i in idx]) + np.array(\n", + " ticks_to_skip\n", + " )\n", + "\n", + " for ax in [rawdata_axes, contrast_axes]:\n", + " sns.despine(ax=ax, bottom=True)\n", + "\n", + " ylim = ax.get_ylim()\n", + " xlim = ax.get_xlim()\n", + " redraw_axes_kwargs[\"y\"] = ylim[0]\n", + "\n", + " if two_col_sankey:\n", + " for k, start_tick in enumerate(ticks_to_start_twocol_sankey):\n", + " end_tick = rightend_ticks[k]\n", + " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", + " else:\n", + " for k, start_tick in enumerate(ticks_to_skip):\n", + " end_tick = rightend_ticks[k]\n", + " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", + "\n", + " ax.set_ylim(ylim)\n", + " del redraw_axes_kwargs[\"y\"]\n", + "\n", + "def General_Plot_Aesthetic_Adjustments(show_delta2, show_mini_meta, contrast_axes, redraw_axes_kwargs, plot_kwargs,\n", + " yvar, effect_size_type, proportional, effectsize_df, is_paired, float_contrast,\n", + " rawdata_axes, og_ylim_raw, effect_size):\n", + "\n", + " if show_delta2 or show_mini_meta:\n", + " ylim = contrast_axes.get_ylim()\n", + " redraw_axes_kwargs[\"y\"] = ylim[0]\n", + " x_ticks = contrast_axes.get_xticks()\n", + " contrast_axes.hlines(xmin=x_ticks[-2], xmax=x_ticks[-1], **redraw_axes_kwargs)\n", + " del redraw_axes_kwargs[\"y\"]\n", + "\n", + " # Set raw axes y-label.\n", + " swarm_label = plot_kwargs[\"swarm_label\"]\n", + " if swarm_label is None and yvar is None:\n", + " swarm_label = \"value\"\n", + " elif swarm_label is None and yvar is not None:\n", + " swarm_label = yvar\n", + "\n", + " bar_label = plot_kwargs[\"bar_label\"]\n", + " if bar_label is None and effect_size_type != \"cohens_h\":\n", + " bar_label = \"proportion of success\"\n", + " elif bar_label is None and effect_size_type == \"cohens_h\":\n", + " bar_label = \"value\"\n", + "\n", + " # Place contrast axes y-label.\n", + " contrast_label_dict = {\n", + " \"mean_diff\": \"mean difference\",\n", + " \"median_diff\": \"median difference\",\n", + " \"cohens_d\": \"Cohen's d\",\n", + " \"hedges_g\": \"Hedges' g\",\n", + " \"cliffs_delta\": \"Cliff's delta\",\n", + " \"cohens_h\": \"Cohen's h\",\n", + " \"delta_g\": \"mean difference\",\n", + " }\n", + "\n", + " if proportional and effect_size_type != \"cohens_h\":\n", + " default_contrast_label = \"proportion difference\"\n", + " elif effect_size_type == \"delta_g\":\n", + " default_contrast_label = \"Hedges' g\"\n", + " else:\n", + " default_contrast_label = contrast_label_dict[effectsize_df.effect_size]\n", + "\n", + " if plot_kwargs[\"contrast_label\"] is None:\n", + " if is_paired:\n", + " contrast_label = \"paired\\n{}\".format(default_contrast_label)\n", + " else:\n", + " contrast_label = default_contrast_label\n", + " contrast_label = contrast_label.capitalize()\n", + " else:\n", + " contrast_label = plot_kwargs[\"contrast_label\"]\n", + "\n", + " if plot_kwargs[\"fontsize_rawylabel\"] is not None:\n", + " fontsize_rawylabel = plot_kwargs[\"fontsize_rawylabel\"]\n", + " if plot_kwargs[\"fontsize_contrastylabel\"] is not None:\n", + " fontsize_contrastylabel = plot_kwargs[\"fontsize_contrastylabel\"]\n", + " if plot_kwargs[\"fontsize_delta2label\"] is not None:\n", + " fontsize_delta2label = plot_kwargs[\"fontsize_delta2label\"]\n", + "\n", + " contrast_axes.set_ylabel(contrast_label, fontsize=fontsize_contrastylabel)\n", + " if float_contrast:\n", + " contrast_axes.yaxis.set_label_position(\"right\")\n", + "\n", + " # Set the rawdata axes labels appropriately\n", + " if not proportional:\n", + " rawdata_axes.set_ylabel(swarm_label, fontsize=fontsize_rawylabel)\n", + " else:\n", + " rawdata_axes.set_ylabel(bar_label, fontsize=fontsize_rawylabel)\n", + " rawdata_axes.set_xlabel(\"\")\n", + "\n", + " # Because we turned the axes frame off, we also need to draw back\n", + " # the y-spine for both axes.\n", + " if not float_contrast:\n", + " rawdata_axes.set_xlim(contrast_axes.get_xlim())\n", + " og_xlim_raw = rawdata_axes.get_xlim()\n", + " rawdata_axes.vlines(\n", + " og_xlim_raw[0], og_ylim_raw[0], og_ylim_raw[1], **redraw_axes_kwargs\n", + " )\n", + "\n", + " og_xlim_contrast = contrast_axes.get_xlim()\n", + "\n", + " if float_contrast:\n", + " xpos = og_xlim_contrast[1]\n", + " else:\n", + " xpos = og_xlim_contrast[0]\n", + "\n", + " og_ylim_contrast = contrast_axes.get_ylim()\n", + " contrast_axes.vlines(\n", + " xpos, og_ylim_contrast[0], og_ylim_contrast[1], **redraw_axes_kwargs\n", + " )\n", + "\n", + " if show_delta2:\n", + " if plot_kwargs[\"delta2_label\"] is not None:\n", + " delta2_label = plot_kwargs[\"delta2_label\"]\n", + " elif effect_size == \"mean_diff\":\n", + " delta2_label = \"delta - delta\"\n", + " else:\n", + " delta2_label = \"deltas' g\"\n", + " delta2_axes = contrast_axes.twinx()\n", + " delta2_axes.set_frame_on(False)\n", + " delta2_axes.set_ylabel(delta2_label, fontsize=fontsize_delta2label)\n", + " og_xlim_delta = contrast_axes.get_xlim()\n", + " og_ylim_delta = contrast_axes.get_ylim()\n", + " delta2_axes.set_ylim(og_ylim_delta)\n", + " delta2_axes.vlines(\n", + " og_xlim_delta[1], og_ylim_delta[0], og_ylim_delta[1], **redraw_axes_kwargs\n", + " )" ] } ], diff --git a/nbs/API/plot_tools.ipynb b/nbs/API/plot_tools.ipynb index 7187025b..75574700 100644 --- a/nbs/API/plot_tools.ipynb +++ b/nbs/API/plot_tools.ipynb @@ -64,6 +64,7 @@ "import matplotlib.pyplot as plt\n", "import matplotlib.lines as mlines\n", "import matplotlib.axes as axes\n", + "import matplotlib.patches as mpatches\n", "from collections import defaultdict\n", "from typing import List, Tuple, Dict, Iterable, Union\n", "from pandas.api.types import CategoricalDtype\n", @@ -228,7 +229,8 @@ "\n", " kwargs[\"zorder\"] = kwargs[\"zorder\"]\n", "\n", - " for xpos, central_measure in enumerate(central_measures):\n", + " for xpos, val in enumerate(central_measures.index):\n", + " central_measure = central_measures[val]\n", " kwargs[\"color\"] = custom_palette[xpos]\n", "\n", " if method == \"sankey_error_bar\":\n", @@ -236,8 +238,14 @@ " else:\n", " _xpos = xpos + offset[xpos]\n", "\n", - " low = lows[xpos]\n", - " high = highs[xpos]\n", + " # Fix for the non-string x-axis issue #108\n", + " if central_measures.index.dtype.name == \"category\":\n", + " low = lows[xpos]\n", + " high = highs[xpos]\n", + " else: \n", + " low = lows[val]\n", + " high = highs[val]\n", + "\n", " if low == high == central_measure:\n", " low_to_mean = mlines.Line2D(\n", " [_xpos, _xpos], [low, central_measure], **kwargs\n", @@ -408,7 +416,7 @@ " strip_on: bool = True, # if True, draw strip for each group comparison\n", " one_sankey: bool = False, # if True, only draw one sankey diagram\n", " right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels\n", - " align: bool = \"center\", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick\n", + " align: str = \"center\", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick\n", "):\n", " \"\"\"\n", " Make a single Sankey diagram showing proportion flow from left to right\n", @@ -489,6 +497,7 @@ " if align not in (\"center\", \"edge\"):\n", " err = \"{} assigned for `align` is not valid.\".format(align)\n", " raise ValueError(err)\n", + " \n", " if align == \"center\":\n", " try:\n", " leftpos = xpos - width / 2\n", @@ -675,8 +684,9 @@ " data: pd.DataFrame,\n", " xvar: str, # x column to be plotted.\n", " yvar: str, # y column to be plotted.\n", - " left_idx: str, # the value in column xvar that is on the left side of each sankey diagram\n", - " right_idx: str, # the value in column xvar that is on the right side of each sankey diagram, if len(left_idx) == 1, it will be broadcasted to the same length as right_idx, otherwise it should have the same length as right_idx\n", + " temp_all_plot_groups: list,\n", + " idx: list,\n", + " temp_idx: list,\n", " left_labels: list = None, # labels for the left side of the diagram. The diagram will be sorted by these labels.\n", " right_labels: list = None, # labels for the right side of the diagram. The diagram will be sorted by these labels.\n", " palette: str | dict = None,\n", @@ -722,6 +732,30 @@ " if ax is None:\n", " ax = plt.gca()\n", "\n", + " left_idx = []\n", + " right_idx = []\n", + " # Design for Sankey Flow Diagram\n", + " sankey_idx = (\n", + " [\n", + " (control, test)\n", + " for i in idx\n", + " for control, test in zip(i[:], (i[1:] + (i[0],)))\n", + " ]\n", + " if flow\n", + " else temp_idx\n", + " )\n", + " for i in sankey_idx:\n", + " left_idx.append(i[0])\n", + " right_idx.append(i[1])\n", + "\n", + " if len(temp_all_plot_groups) == 2:\n", + " one_sankey = True\n", + " left_idx.pop()\n", + " right_idx.pop() # Remove the last element from two lists\n", + "\n", + " # two_col_sankey = True if proportional == True and one_sankey == False and sankey == True and flow == False else False\n", + "\n", + "\n", " allLabels = pd.Series(np.sort(data[yvar].unique())[::-1]).unique()\n", "\n", " # Check if all the elements in left_idx and right_idx are in xvar column\n", @@ -831,7 +865,672 @@ " else:\n", " sankey_ticks = [broadcasted_left[0], right_idx[0]]\n", " ax.set_xticks([0, 1])\n", - " ax.set_xticklabels(sankey_ticks)" + " ax.set_xticklabels(sankey_ticks)\n", + "\n", + " return left_idx, right_idx\n", + "\n", + "def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object,\n", + " float_contrast: bool,summary_bars_kwargs: dict, ci_type: str,\n", + " ticks_to_plot: list, color_col: str, swarm_colors: list, \n", + " proportional: bool, is_paired: bool):\n", + " \"\"\"\n", + " Add summary bars to the contrast plot.\n", + "\n", + " Parameters\n", + " ----------\n", + " summary_bars : list\n", + " List of indices of the contrast objects to plot summary bars for.\n", + " results : object (Dataframe)\n", + " Dataframe of contrast object comparisons.\n", + " ax_to_plot : object\n", + " Matplotlib axis object to plot on.\n", + " float_contrast : bool\n", + " Whether the DABEST plot uses Gardner-Altman or Cummings.\n", + " summary_bars_kwargs : dict\n", + " Keyword arguments for the summary bars.\n", + " ci_type : str \n", + " Type of confidence interval to plot.\n", + " ticks_to_plot : list\n", + " List of indices of the contrast objects.\n", + " color_col : str\n", + " Column name of the color column.\n", + " swarm_colors : list\n", + " List of colors used in the plot.\n", + " proportional : bool\n", + " Whether the data is proportional.\n", + " is_paired : bool\n", + " Whether the data is paired.\n", + " \"\"\"\n", + "# Begin checks \n", + " if not isinstance(summary_bars, list):\n", + " raise TypeError(\"summary_bars must be a list of indices (ints).\")\n", + " if not all(isinstance(i, int) for i in summary_bars):\n", + " raise TypeError(\"summary_bars must be a list of indices (ints).\")\n", + " if any(i >= len(results) for i in summary_bars):\n", + " raise ValueError(\"Index {} chosen is out of range for the contrast objects.\".format([i for i in summary_bars if i >= len(results)]))\n", + " if float_contrast:\n", + " raise ValueError(\"summary_bars cannot be used with Gardner-Altman plots.\")\n", + "# End checks\n", + " else:\n", + " summary_xmin, summary_xmax = ax_to_plot.get_xlim()\n", + " summary_bars_colors = [summary_bars_kwargs.get('color')]*(max(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors\n", + " summary_bars_kwargs.pop('color')\n", + " for summary_index in summary_bars:\n", + " if ci_type == \"bca\":\n", + " summary_ci_low = results.bca_low[summary_index]\n", + " summary_ci_high = results.bca_high[summary_index]\n", + " else:\n", + " summary_ci_low = results.pct_low[summary_index]\n", + " summary_ci_high = results.pct_high[summary_index]\n", + "\n", + " summary_color = summary_bars_colors[ticks_to_plot[summary_index]]\n", + "\n", + " ax_to_plot.add_patch(mpatches.Rectangle((summary_xmin,summary_ci_low),summary_xmax+1, \n", + " summary_ci_high-summary_ci_low, zorder=-2, color=summary_color, **summary_bars_kwargs))\n", + "\n", + "\n", + "def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object,\n", + " ticks_to_plot: list, contrast_bars_kwargs: dict, color_col: str, \n", + " plot_palette_raw: dict, show_mini_meta: bool, mini_meta_delta: object, \n", + " show_delta2: bool, delta_delta: object, proportional: bool, is_paired: bool):\n", + " \"\"\"\n", + " Add contrast bars to the contrast plot.\n", + "\n", + " Parameters\n", + " ----------\n", + " results : object (Dataframe)\n", + " Dataframe of contrast object comparisons.\n", + " ax_to_plot : object\n", + " Matplotlib axis object to plot on.\n", + " swarm_plot_ax : object (ax)\n", + " Matplotlib axis object of the swarm plot.\n", + " ticks_to_plot : list\n", + " List of indices of the contrast objects.\n", + " contrast_bars_kwargs : dict \n", + " Keyword arguments for the contrast bars.\n", + " color_col : str\n", + " Column name of the color column.\n", + " plot_palette_raw : dict\n", + " Dictionary of colors used in the plot.\n", + " show_mini_meta : bool \n", + " Whether to show the mini meta-analysis.\n", + " mini_meta_delta : object \n", + " Mini meta-analysis object.\n", + " show_delta2 : bool\n", + " Whether to show the delta-delta.\n", + " delta_delta : object\n", + " delta-delta object.\n", + " proportional : bool\n", + " Whether the data is proportional.\n", + " is_paired : bool\n", + " Whether the data is paired.\n", + " \"\"\"\n", + " contrast_means = []\n", + " for j, tick in enumerate(ticks_to_plot):\n", + " contrast_means.append(results.difference[j])\n", + "\n", + " contrast_bars_colors = (\n", + " [contrast_bars_kwargs.get('color')] * (max(ticks_to_plot) + 1) \n", + " if contrast_bars_kwargs.get('color') is not None \n", + " else ['black'] * (max(ticks_to_plot) + 1) \n", + " if color_col is not None or (proportional and is_paired) or is_paired \n", + " else list(plot_palette_raw.values())\n", + " )\n", + " contrast_bars_kwargs.pop('color')\n", + " for contrast_bars_x,contrast_bars_y in zip(ticks_to_plot, contrast_means):\n", + " ax_to_plot.add_patch(mpatches.Rectangle((contrast_bars_x-0.25,0),0.5, contrast_bars_y, zorder=-1, color=contrast_bars_colors[contrast_bars_x], **contrast_bars_kwargs))\n", + "\n", + " if show_mini_meta:\n", + " ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, mini_meta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))\n", + "\n", + " if show_delta2:\n", + " ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, delta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))\n", + "\n", + "def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,\n", + " swarm_bars_kwargs: dict, color_col: str, plot_palette_raw: dict, is_paired: bool):\n", + " \"\"\"\n", + " Add bars to the raw data plot.\n", + "\n", + " Parameters\n", + " ----------\n", + " plot_data : object (Dataframe)\n", + " Dataframe of the plot data.\n", + " xvar : str\n", + " Column name of the x variable.\n", + " yvar : str\n", + " Column name of the y variable.\n", + " ax : object \n", + " Matplotlib axis object to plot on.\n", + " swarm_bars_kwargs : dict\n", + " Keyword arguments for the swarm bars.\n", + " color_col : str\n", + " Column name of the color column.\n", + " plot_palette_raw : dict\n", + " Dictionary of colors used in the plot.\n", + " is_paired : bool\n", + " Whether the data is paired.\n", + " \"\"\"\n", + "\n", + " # if is_paired:\n", + " # swarm_bar_xlocs_adjustleft = {'right': -0.2, 'left': -0.2, 'center': -0.2}\n", + " # swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1} \n", + " # else:\n", + " # swarm_bar_xlocs_adjustleft = {'right': 0, 'left': -0.4, 'center': -0.2}\n", + " # swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1}\n", + "\n", + " if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype):\n", + " swarm_bars_order = pd.unique(plot_data[xvar]).categories\n", + " else:\n", + " swarm_bars_order = pd.unique(plot_data[xvar])\n", + "\n", + " swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)\n", + " # swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(max(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors\n", + " swarm_bars_colors = (\n", + " [swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1) \n", + " if swarm_bars_kwargs.get('color') is not None \n", + " else ['black']*(len(swarm_bars_order)+1)\n", + " if color_col is not None or is_paired\n", + " else list(plot_palette_raw.values())\n", + " )\n", + " swarm_bars_kwargs.pop('color')\n", + " for swarm_bars_x,swarm_bars_y,c in zip(np.arange(0,len(swarm_bars_order)+1,1), swarm_means, swarm_bars_colors):\n", + " ax.add_patch(mpatches.Rectangle((swarm_bars_x-0.25,0),\n", + " 0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs))\n", + "\n", + "def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, ticks_to_plot: list, delta_text_kwargs: dict, color_col: str, \n", + " swarm_colors: list, is_paired: bool, proportional: bool, float_contrast: bool,\n", + " show_mini_meta: bool, mini_meta_delta: object, show_delta2: bool, delta_delta: object):\n", + " \"\"\"\n", + " Add text to the contrast plot.\n", + "\n", + " Parameters\n", + " ----------\n", + " results : object (Dataframe)\n", + " Dataframe of contrast object comparisons.\n", + " ax_to_plot : object\n", + " Matplotlib axis object to plot on.\n", + " swarm_plot_ax : object\n", + " Matplotlib axis object of the swarm plot.\n", + " ticks_to_plot : list\n", + " List of indices of the contrast objects.\n", + " delta_text_kwargs : dict\n", + " Keyword arguments for the delta text.\n", + " color_col : str\n", + " Column name of the color column.\n", + " swarm_colors : list\n", + " List of colors used in the plot.\n", + " is_paired : bool\n", + " Whether the data is paired.\n", + " proportional : bool\n", + " Whether the data is proportional.\n", + " float_contrast : bool\n", + " Whether the DABEST plot uses Gardner-Altman or Cummings\n", + " show_mini_meta : bool\n", + " Whether to show the mini meta-analysis.\n", + " mini_meta_delta : object\n", + " Mini meta-analysis object.\n", + " show_delta2 : bool\n", + " Whether to show the delta-delta.\n", + " delta_delta : object\n", + " delta-delta object.\n", + " \"\"\"\n", + " # Begin checks\n", + " delta_text_x_location = delta_text_kwargs.get('x_location')\n", + " if delta_text_x_location != 'right' and delta_text_x_location != 'left':\n", + " raise ValueError(\"delta_text_kwargs['x_location'] must be either 'right' or 'left'.\")\n", + " if float_contrast:\n", + " delta_text_x_location = 'left'\n", + " delta_text_kwargs[\"va\"] = 'bottom' if results.difference[0] >= 0 else 'top'\n", + " delta_text_kwargs.pop('x_location')\n", + "\n", + " delta_text_colors = [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1) if delta_text_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors\n", + " if show_mini_meta or show_delta2: delta_text_colors.append('black')\n", + " delta_text_kwargs.pop('color')\n", + "\n", + " total_ticks = len(ticks_to_plot) + 1 if show_mini_meta or show_delta2 else len(ticks_to_plot)\n", + "\n", + " # Collect the Y-values for the delta text\n", + " Delta_Values = []\n", + " for j, tick in enumerate(ticks_to_plot):\n", + " Delta_Values.append(results.difference[j])\n", + " if show_delta2: Delta_Values.append(delta_delta.difference)\n", + " if show_mini_meta: Delta_Values.append(mini_meta_delta.difference)\n", + "\n", + " # Collect the X-coordinates for the delta text\n", + " delta_text_x_coordinates = delta_text_kwargs.get('x_coordinates')\n", + "\n", + " if delta_text_x_coordinates is not None:\n", + " if not isinstance(delta_text_x_coordinates, list):\n", + " raise TypeError(\"delta_text_kwargs['x_coordinates'] must be a list of x-coordinates.\")\n", + " if len(delta_text_x_coordinates) != len(total_ticks):\n", + " raise ValueError(\"delta_text_kwargs['x_coordinates'] must have the same length as the number of ticks to plot.\")\n", + " else:\n", + " delta_text_x_coordinates = ticks_to_plot\n", + " X_Adjust = 0.48 if delta_text_x_location == 'right' else -0.38\n", + " delta_text_x_coordinates = [x+X_Adjust for x in delta_text_x_coordinates]\n", + " if show_mini_meta: delta_text_x_coordinates.append(max(swarm_plot_ax.get_xticks())+2+X_Adjust)\n", + " if show_delta2: delta_text_x_coordinates.append(max(swarm_plot_ax.get_xticks())+2-0.35)\n", + " if show_mini_meta or show_delta2: ticks_to_plot.append(max(ticks_to_plot)+1)\n", + " delta_text_kwargs.pop('x_coordinates')\n", + "\n", + " # Collect the Y-coordinates for the delta text\n", + " delta_text_y_coordinates = delta_text_kwargs.get('y_coordinates')\n", + "\n", + " if delta_text_y_coordinates is not None:\n", + " if not isinstance(delta_text_y_coordinates, list):\n", + " raise TypeError(\"delta_text_kwargs['y_coordinates'] must be a list of y-coordinates.\")\n", + " if len(delta_text_y_coordinates) != len(total_ticks):\n", + " raise ValueError(\"delta_text_kwargs['y_coordinates'] must have the same length as the number of ticks to plot.\")\n", + " else:\n", + " delta_text_y_coordinates = Delta_Values\n", + "\n", + " delta_text_kwargs.pop('y_coordinates')\n", + "\n", + " # Plot the delta text\n", + " for x,y,t,tick in zip(delta_text_x_coordinates, delta_text_y_coordinates,Delta_Values,ticks_to_plot):\n", + " Delta_Text = np.format_float_positional(t, precision=2, sign=True, trim=\"k\", min_digits=2)\n", + " ax_to_plot.text(x, y, Delta_Text, color=delta_text_colors[tick], zorder=5, **delta_text_kwargs)\n", + "\n", + "\n", + "def DeltaDotsPlotter(plot_data, contrast_axes, delta_id_col, idx, xvar, yvar, is_paired, color_col, float_contrast, plot_palette_raw, delta_dot_kwargs):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " plot_data : object (Dataframe)\n", + " Dataframe of the plot data.\n", + " contrast_axes : object\n", + " Matplotlib axis object to plot on.\n", + " delta_id_col : str\n", + " Column name of the delta id column.\n", + " idx : list\n", + " List of indices of the contrast objects.\n", + " xvar : str\n", + " Column name of the x variable.\n", + " yvar : str\n", + " Column name of the y variable.\n", + " is_paired : bool\n", + " Whether the data is paired.\n", + " color_col : str\n", + " Column name of the color column.\n", + " float_contrast : bool\n", + " Whether the DABEST plot uses Gardner-Altman or Cummings\n", + " plot_palette_raw : dict\n", + " Dictionary of colors used in the plot.\n", + " delta_dot_kwargs : dict\n", + " Keyword arguments for the delta dots.\n", + " \"\"\"\n", + " \n", + " # Checks and initializations\n", + " from .plot_tools import swarmplot\n", + "\n", + " if color_col is not None:\n", + " plot_palette_deltapts = plot_palette_raw\n", + " delta_plot_data = plot_data[[xvar, yvar, delta_id_col, color_col]]\n", + " else:\n", + " plot_palette_deltapts = \"k\"\n", + " delta_plot_data = plot_data[[xvar, yvar, delta_id_col]]\n", + "\n", + " # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n", + " jitter = 0.6 if float_contrast else 1 \n", + "\n", + " # Create dataframe of delta values\n", + " final_deltas = pd.DataFrame()\n", + " for i in idx:\n", + " for j in i:\n", + " if i.index(j) != 0:\n", + " temp_df_exp = delta_plot_data[\n", + " delta_plot_data[xvar].str.contains(j)\n", + " ].reset_index(drop=True)\n", + " if is_paired == \"baseline\":\n", + " temp_df_cont = delta_plot_data[\n", + " delta_plot_data[xvar].str.contains(i[0])\n", + " ].reset_index(drop=True)\n", + " elif is_paired == \"sequential\":\n", + " temp_df_cont = delta_plot_data[\n", + " delta_plot_data[xvar].str.contains(\n", + " i[i.index(j) - 1]\n", + " )\n", + " ].reset_index(drop=True)\n", + " delta_df = temp_df_exp.copy()\n", + " delta_df[yvar] = temp_df_exp[yvar] - temp_df_cont[yvar]\n", + " final_deltas = pd.concat([final_deltas, delta_df])\n", + "\n", + " # Plot the delta dots\n", + " swarmplot(\n", + " data=final_deltas,\n", + " x=xvar,\n", + " y=yvar,\n", + " ax=contrast_axes,\n", + " order=None,\n", + " hue=color_col,\n", + " palette=plot_palette_deltapts,\n", + " jitter=jitter,\n", + " is_drop_gutter=True,\n", + " gutter_limit=1,\n", + " **delta_dot_kwargs)\n", + " contrast_axes.legend().set_visible(False)\n", + "\n", + "\n", + "def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palette_raw, slopegraph_kwargs, rawdata_axes, ytick_color, temp_idx):\n", + " \n", + " # Pivot the long (melted) data.\n", + " if color_col is None:\n", + " pivot_values = [yvar]\n", + " else:\n", + " pivot_values = [yvar, color_col]\n", + " pivoted_plot_data = pd.pivot(\n", + " data=plot_data,\n", + " index=dabest_obj.id_col,\n", + " columns=xvar,\n", + " values=pivot_values,\n", + " )\n", + "\n", + " x_start = 0\n", + " for ii, current_tuple in enumerate(temp_idx):\n", + " current_pair = pivoted_plot_data.loc[\n", + " :, pd.MultiIndex.from_product([pivot_values, current_tuple])\n", + " ].dropna()\n", + " grp_count = len(current_tuple)\n", + " # Iterate through the data for the current tuple.\n", + " for ID, observation in current_pair.iterrows():\n", + " x_points = [t for t in range(x_start, x_start + grp_count)]\n", + " y_points = observation[yvar].tolist()\n", + "\n", + " if color_col is None:\n", + " slopegraph_kwargs[\"color\"] = ytick_color\n", + " else:\n", + " color_key = observation[color_col][0]\n", + " if isinstance(color_key, (str, np.int64, np.float64)):\n", + " slopegraph_kwargs[\"color\"] = plot_palette_raw[color_key]\n", + " slopegraph_kwargs[\"label\"] = color_key\n", + "\n", + " rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)\n", + "\n", + " x_start = x_start + grp_count\n", + "\n", + "def plot_minimeta_or_deltadelta_violins(show_mini_meta, effectsize_df, ci_type, rawdata_axes,\n", + " contrast_axes, violinplot_kwargs, halfviolin_alpha, ytick_color, \n", + " es_marker_size, group_summary_kwargs, contrast_xtick_labels, effect_size\n", + " ):\n", + " if show_mini_meta:\n", + " mini_meta_delta = effectsize_df.mini_meta_delta\n", + " data = mini_meta_delta.bootstraps_weighted_delta\n", + " difference = mini_meta_delta.difference\n", + " if ci_type == \"bca\":\n", + " ci_low = mini_meta_delta.bca_low\n", + " ci_high = mini_meta_delta.bca_high\n", + " else:\n", + " ci_low = mini_meta_delta.pct_low\n", + " ci_high = mini_meta_delta.pct_high\n", + " else:\n", + " delta_delta = effectsize_df.delta_delta\n", + " data = delta_delta.bootstraps_delta_delta\n", + " difference = delta_delta.difference\n", + " if ci_type == \"bca\":\n", + " ci_low = delta_delta.bca_low\n", + " ci_high = delta_delta.bca_high\n", + " else:\n", + " ci_low = delta_delta.pct_low\n", + " ci_high = delta_delta.pct_high\n", + " # Create the violinplot.\n", + " # New in v0.2.6: drop negative infinities before plotting.\n", + " position = max(rawdata_axes.get_xticks()) + 2\n", + " v = contrast_axes.violinplot(\n", + " data[~np.isinf(data)], positions=[position], **violinplot_kwargs\n", + " )\n", + "\n", + " fc = \"grey\"\n", + "\n", + " halfviolin(v, fill_color=fc, alpha=halfviolin_alpha)\n", + "\n", + " # Plot the effect size.\n", + " contrast_axes.plot(\n", + " [position],\n", + " difference,\n", + " marker=\"o\",\n", + " color=ytick_color,\n", + " markersize=es_marker_size,\n", + " )\n", + " # Plot the confidence interval.\n", + " contrast_axes.plot(\n", + " [position, position],\n", + " [ci_low, ci_high],\n", + " linestyle=\"-\",\n", + " color=ytick_color,\n", + " linewidth=group_summary_kwargs[\"lw\"],\n", + " )\n", + " if show_mini_meta:\n", + " contrast_xtick_labels.extend([\"\", \"Weighted delta\"])\n", + " elif effect_size == \"delta_g\":\n", + " contrast_xtick_labels.extend([\"\", \"deltas' g\"])\n", + " else:\n", + " contrast_xtick_labels.extend([\"\", \"delta-delta\"])\n", + " \n", + " return contrast_xtick_labels\n", + "\n", + "\n", + "def effect_size_curve_plotter(ticks_to_plot, results, ci_type, contrast_axes, violinplot_kwargs, halfviolin_alpha, \n", + " ytick_color, es_marker_size, group_summary_kwargs, bootstraps_color_by_group, plot_palette_contrast):\n", + " contrast_xtick_labels = []\n", + " for j, tick in enumerate(ticks_to_plot):\n", + " current_group = results.test[j]\n", + " current_control = results.control[j]\n", + " current_bootstrap = results.bootstraps[j]\n", + " current_effsize = results.difference[j]\n", + " if ci_type == \"bca\":\n", + " current_ci_low = results.bca_low[j]\n", + " current_ci_high = results.bca_high[j]\n", + " else:\n", + " current_ci_low = results.pct_low[j]\n", + " current_ci_high = results.pct_high[j]\n", + "\n", + " # Create the violinplot.\n", + " # New in v0.2.6: drop negative infinities before plotting.\n", + " v = contrast_axes.violinplot(\n", + " current_bootstrap[~np.isinf(current_bootstrap)],\n", + " positions=[tick],\n", + " **violinplot_kwargs\n", + " )\n", + " # Turn the violinplot into half, and color it the same as the swarmplot.\n", + " # Do this only if the color column is not specified.\n", + " # Ideally, the alpha (transparency) fo the violin plot should be\n", + " # less than one so the effect size and CIs are visible.\n", + " if bootstraps_color_by_group:\n", + " fc = plot_palette_contrast[current_group]\n", + " else:\n", + " fc = \"grey\"\n", + "\n", + " halfviolin(v, fill_color=fc, alpha=halfviolin_alpha)\n", + "\n", + " # Plot the effect size.\n", + " contrast_axes.plot(\n", + " [tick],\n", + " current_effsize,\n", + " marker=\"o\",\n", + " color=ytick_color,\n", + " markersize=es_marker_size,\n", + " )\n", + "\n", + " # Plot the confidence interval.\n", + " contrast_axes.plot(\n", + " [tick, tick],\n", + " [current_ci_low, current_ci_high],\n", + " linestyle=\"-\",\n", + " color=ytick_color,\n", + " linewidth=group_summary_kwargs[\"lw\"],\n", + " )\n", + "\n", + " contrast_xtick_labels.append(\n", + " \"{}\\nminus\\n{}\".format(current_group, current_control)\n", + " )\n", + " return current_group, current_control, current_effsize, contrast_xtick_labels\n", + "\n", + "\n", + "def grid_key_WIP(is_paired, idx, all_plot_groups, gridkey_rows, rawdata_axes, contrast_axes,\n", + " plot_data, xvar, yvar, results, show_delta2, show_mini_meta, float_contrast, plot_kwargs,):\n", + " \n", + " gridkey_show_Ns=plot_kwargs[\"gridkey_show_Ns\"]\n", + " gridkey_show_es=plot_kwargs[\"gridkey_show_es\"]\n", + " gridkey_merge_pairs=plot_kwargs[\"gridkey_merge_pairs\"]\n", + " \n", + " # Raise error if there are more than 2 items in any idx and gridkey_merge_pairs is True and is_paired is not None\n", + " if gridkey_merge_pairs and is_paired is not None:\n", + " for i in idx:\n", + " if len(i) > 2:\n", + " warnings.warn(\n", + " \"gridkey_merge_pairs=True only works if all idx in tuples have only two items. gridkey_merge_pairs has automatically been set to False\"\n", + " )\n", + " gridkey_merge_pairs = False\n", + " break\n", + " elif gridkey_merge_pairs and is_paired is None:\n", + " warnings.warn(\n", + " \"gridkey_merge_pairs=True is only applicable for paired data.\"\n", + " )\n", + " gridkey_merge_pairs = False\n", + "\n", + " # Checks for gridkey_merge_pairs and is_paired; if both are true, \"merges\" the gridkey per pair\n", + " if gridkey_merge_pairs and is_paired is not None:\n", + " groups_for_gridkey = []\n", + " for i in idx:\n", + " groups_for_gridkey.append(i[1])\n", + " else:\n", + " groups_for_gridkey = all_plot_groups\n", + "\n", + " # raise errors if gridkey_rows is not a list, or if the list is empty\n", + " if isinstance(gridkey_rows, list) is False:\n", + " raise TypeError(\"gridkey_rows must be a list.\")\n", + " elif len(gridkey_rows) == 0:\n", + " warnings.warn(\"gridkey_rows is an empty list.\")\n", + "\n", + " # raise Warning if an item in gridkey_rows is not contained in any idx\n", + " for i in gridkey_rows:\n", + " in_idx = 0\n", + " for j in groups_for_gridkey:\n", + " if i in j:\n", + " in_idx += 1\n", + " if in_idx == 0:\n", + " if is_paired is not None:\n", + " warnings.warn(\n", + " i\n", + " + \" is not in any idx. Please check. Alternatively, merging gridkey pairs may not be suitable for your data; try passing gridkey_merge_pairs=False.\"\n", + " )\n", + " else:\n", + " warnings.warn(i + \" is not in any idx. Please check.\")\n", + "\n", + " # Populate table: checks if idx for each column contains rowlabel name\n", + " # IF so, marks that element as present w black dot, or space if not present\n", + " table_cellcols = []\n", + " for i in gridkey_rows:\n", + " thisrow = []\n", + " for q in groups_for_gridkey:\n", + " if str(i) in q:\n", + " thisrow.append(\"\\u25CF\")\n", + " else:\n", + " thisrow.append(\"\")\n", + " table_cellcols.append(thisrow)\n", + "\n", + " # Adds a row for Ns with the Ns values\n", + " if gridkey_show_Ns:\n", + " gridkey_rows.append(\"Ns\")\n", + " list_of_Ns = []\n", + " for i in groups_for_gridkey:\n", + " list_of_Ns.append(str(plot_data.groupby(xvar).count()[yvar].loc[i]))\n", + " table_cellcols.append(list_of_Ns)\n", + "\n", + " # Adds a row for effectsizes with effectsize values\n", + " if gridkey_show_es:\n", + " gridkey_rows.append(\"\\u0394\")\n", + " effsize_list = []\n", + " results_list = results.test.to_list()\n", + "\n", + " # get the effect size, append + or -, 2 dec places\n", + " for i in enumerate(groups_for_gridkey):\n", + " if i[1] in results_list:\n", + " curr_esval = results.loc[results[\"test\"] == i[1]][\n", + " \"difference\"\n", + " ].iloc[0]\n", + " curr_esval_str = np.format_float_positional(\n", + " curr_esval,\n", + " precision=2,\n", + " sign=True,\n", + " trim=\"k\",\n", + " min_digits=2,\n", + " )\n", + " effsize_list.append(curr_esval_str)\n", + " else:\n", + " effsize_list.append(\"-\")\n", + "\n", + " table_cellcols.append(effsize_list)\n", + "\n", + " # If Gardner-Altman plot, plot on raw data and not contrast axes\n", + " if float_contrast:\n", + " axes_ploton = rawdata_axes\n", + " else:\n", + " axes_ploton = contrast_axes\n", + "\n", + " # Account for extended x axis in case of show_delta2 or show_mini_meta\n", + " x_groups_for_width = len(groups_for_gridkey)\n", + " if show_delta2 or show_mini_meta:\n", + " x_groups_for_width += 2\n", + " gridkey_width = len(groups_for_gridkey) / x_groups_for_width\n", + "\n", + " gridkey = axes_ploton.table(\n", + " cellText=table_cellcols,\n", + " rowLabels=gridkey_rows,\n", + " cellLoc=\"center\",\n", + " bbox=[\n", + " 0,\n", + " -len(gridkey_rows) * 0.1 - 0.05,\n", + " gridkey_width,\n", + " len(gridkey_rows) * 0.1,\n", + " ],\n", + " **{\"alpha\": 0.5}\n", + " )\n", + "\n", + " # modifies row label cells\n", + " for cell in gridkey._cells:\n", + " if cell[1] == -1:\n", + " gridkey._cells[cell].visible_edges = \"open\"\n", + " gridkey._cells[cell].set_text_props(**{\"ha\": \"right\"})\n", + "\n", + " # turns off both x axes\n", + " rawdata_axes.get_xaxis().set_visible(False)\n", + " contrast_axes.get_xaxis().set_visible(False)\n", + "\n", + "def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color, plot_palette_bar, plot_kwargs, barplot_kwargs):\n", + " # Plot the raw data as a barplot.\n", + " bar1_df = pd.DataFrame(\n", + " {xvar: all_plot_groups, \"proportion\": np.ones(len(all_plot_groups))}\n", + " )\n", + " bar1 = sns.barplot(\n", + " data=bar1_df,\n", + " x=xvar,\n", + " y=\"proportion\",\n", + " ax=rawdata_axes,\n", + " order=all_plot_groups,\n", + " linewidth=2,\n", + " facecolor=(1, 1, 1, 0),\n", + " edgecolor=bar_color,\n", + " zorder=1,\n", + " )\n", + " bar2 = sns.barplot(\n", + " data=plot_data,\n", + " x=xvar,\n", + " y=yvar,\n", + " ax=rawdata_axes,\n", + " order=all_plot_groups,\n", + " palette=plot_palette_bar,\n", + " zorder=1,\n", + " **barplot_kwargs\n", + " )\n", + " # adjust the width of bars\n", + " bar_width = plot_kwargs[\"bar_width\"]\n", + " for bar in bar1.patches:\n", + " x = bar.get_x()\n", + " width = bar.get_width()\n", + " centre = x + width / 2.0\n", + " bar.set_x(centre - bar_width / 2.0)\n", + " bar.set_width(bar_width)" ] }, { @@ -854,6 +1553,7 @@ " size: float = 5,\n", " side: str = \"center\",\n", " jitter: float = 1,\n", + " filled: Union[bool, List, Tuple] = True,\n", " is_drop_gutter: bool = True,\n", " gutter_limit: float = 0.5,\n", " **kwargs,\n", @@ -886,6 +1586,11 @@ " The side on which points are swarmed (\"center\", \"left\", or \"right\"). Default is \"center\".\n", " jitter : int | float\n", " Determines the distance between points. Default is 1.\n", + " filled : bool | List | Tuple\n", + " Determines whether the dots in the swarmplot are filled or not. If set to False,\n", + " dots are not filled. If provided as a List or Tuple, it should contain boolean values,\n", + " each corresponding to a swarm group in order, indicating whether the dot should be\n", + " filled or not.\n", " is_drop_gutter : bool\n", " If True, drop points that hit the gutters; otherwise, readjust them.\n", " gutter_limit : int | float\n", @@ -899,7 +1604,7 @@ " Matplotlib AxesSubplot object for which the swarm plot has been drawn on.\n", " \"\"\"\n", " s = SwarmPlot(data, x, y, ax, order, hue, palette, zorder, size, side, jitter)\n", - " ax = s.plot(is_drop_gutter, gutter_limit, ax, **kwargs)\n", + " ax = s.plot(is_drop_gutter, gutter_limit, ax, filled, **kwargs)\n", " return ax\n", "\n", "\n", @@ -1059,7 +1764,9 @@ " if not isinstance(self.__jitter, (int, float)):\n", " raise ValueError(\"`jitter` must be a scalar or float.\")\n", " if not isinstance(self.__palette, (str, Iterable)):\n", - " raise ValueError(\"`palette` must be either a string indicating a color name or an Iterable.\")\n", + " raise ValueError(\n", + " \"`palette` must be either a string indicating a color name or an Iterable.\"\n", + " )\n", " if self.__hue is not None and not isinstance(self.__hue, str):\n", " raise ValueError(\"`hue` must be either a string or None.\")\n", " if self.__order is not None and not isinstance(self.__order, Iterable):\n", @@ -1089,7 +1796,6 @@ " err = \"`palette` cannot be an empty string. It must be either a string indicating a color name or an Iterable.\"\n", " raise ValueError(err)\n", " if isinstance(self.__palette, dict):\n", - " # TODO: to add detection of when dict length is less than size of unique_items\n", " for group_i, color_i in self.__palette.items():\n", " if group_i not in pd.unique(data[color_col]):\n", " err = (\n", @@ -1099,8 +1805,10 @@ " )\n", " raise IndexError(err)\n", " if isinstance(color_i, str) and color_i.strip() == \"\":\n", - " err = \"The color mapping for {0} in `palette` is an empty string. It must contain a color name.\".format(group_i)\n", - " raise ValueError(err) \n", + " err = \"The color mapping for {0} in `palette` is an empty string. It must contain a color name.\".format(\n", + " group_i\n", + " )\n", + " raise ValueError(err)\n", "\n", " if side.lower() not in [\"center\", \"right\", \"left\"]:\n", " raise ValueError(\n", @@ -1302,7 +2010,12 @@ " return points_data\n", "\n", " def plot(\n", - " self, is_drop_gutter: bool, gutter_limit: float, ax: axes.Subplot, **kwargs\n", + " self,\n", + " is_drop_gutter: bool,\n", + " gutter_limit: float,\n", + " ax: axes.Subplot,\n", + " filled: Union[bool, List, Tuple],\n", + " **kwargs,\n", " ) -> axes.Subplot:\n", " \"\"\"\n", " Generate a swarm plot.\n", @@ -1315,6 +2028,11 @@ " The limit for points hitting the gutters.\n", " ax : axes.Subplot\n", " The matplotlib figure object to which the swarm plot will be added.\n", + " filled : bool | List | Tuple\n", + " Determines whether the dots in the swarmplot are filled or not. If set to False,\n", + " dots are not filled. If provided as a List or Tuple, it should contain boolean values,\n", + " each corresponding to a swarm group in order, indicating whether the dot should be\n", + " filled or not.\n", " **kwargs:\n", " Additional keyword arguments to be passed to the scatter plot.\n", "\n", @@ -1328,12 +2046,28 @@ " raise ValueError(\"`is_drop_gutter` must be a boolean.\")\n", " if not isinstance(gutter_limit, (int, float)):\n", " raise ValueError(\"`gutter_limit` must be a scalar or float.\")\n", + " if not isinstance(filled, (bool, list, tuple)):\n", + " raise ValueError(\"`filled` must be a boolean, list or tuple.\")\n", + "\n", + " # More thorough input validation checks\n", + " if isinstance(filled, (list, tuple)):\n", + " if len(filled) != len(self.__order):\n", + " err = (\n", + " \"There are {0} unique values in `x` column in `data` \"\n", + " \"but `filled` has a length of {1}. If `filled` is a list \"\n", + " \"or a tuple, it must have the same length as the number of \"\n", + " \"unique values/groups in the `x` column of data.\"\n", + " ).format(len(self.__order), len(filled))\n", + " raise ValueError(err)\n", + " if not all(isinstance(_, bool) for _ in filled):\n", + " raise ValueError(\"All values in `filled` must be a boolean.\")\n", "\n", " # Assumptions are that self.__data_copy is already sorted according to self.__order\n", " x_position = (\n", " 0 # x-coordinate of center of each individual swarm of the swarm plot\n", " )\n", " x_tick_tabels = []\n", + "\n", " for group_i, values_i in self.__data_copy.groupby(self.__x):\n", " x_new = []\n", " values_i_y = values_i[self.__y]\n", @@ -1371,6 +2105,10 @@ " cmap = []\n", " for cmap_group_i in cmap_values:\n", " cmap.append(self.__palette[cmap_group_i])\n", + "\n", + " # WIP: legend for swarm plot\n", + " swarm_legend_kwargs = {'colors':cmap, 'labels':cmap_values, 'index':index}\n", + "\n", " cmap = ListedColormap(cmap)\n", " ax.scatter(\n", " values_i[\"x_new\"],\n", @@ -1382,23 +2120,55 @@ " edgecolor=\"face\",\n", " **kwargs,\n", " )\n", + "\n", " else:\n", " # color swarms based on `x` column\n", + " if not isinstance(filled, bool):\n", + " facecolor = (\n", + " \"none\"\n", + " if not filled[x_position - 1]\n", + " else self.__palette[group_i]\n", + " )\n", + " else:\n", + " facecolor = \"none\" if not filled else self.__palette[group_i]\n", + "\n", " ax.scatter(\n", " values_i[\"x_new\"],\n", " values_i[self.__y],\n", " s=self.__size,\n", - " c=self.__palette[group_i],\n", " zorder=self.__zorder,\n", - " edgecolor=\"face\",\n", + " facecolor=facecolor,\n", + " edgecolor=self.__palette[group_i],\n", + " label=group_i,\n", " **kwargs,\n", " )\n", "\n", + " # Handling of legends\n", + " # This is currently a workaround because c and cmap is unable to map the labels when calling scatter()\n", + " # labels has to be used to designate legend labels and handles in scatter() due to the potential calling of ax.get_legend_handles_labels()\n", + " if self.__hue is not None:\n", + " for cmap_group_i in self.__palette:\n", + " ax.scatter(\n", + " [],\n", + " [],\n", + " c=self.__palette[cmap_group_i],\n", + " label=cmap_group_i,\n", + " )\n", + " handles, labels = ax.get_legend_handles_labels()\n", + "\n", " ax.get_xaxis().set_ticks(np.arange(x_position))\n", " ax.get_xaxis().set_ticklabels(x_tick_tabels)\n", "\n", - " return ax" + " return ax, swarm_legend_kwargs if self.__hue is not None else None" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "022ea903", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/nbs/API/plotter.ipynb b/nbs/API/plotter.ipynb index 7e054ea4..8b5f365c 100644 --- a/nbs/API/plotter.ipynb +++ b/nbs/API/plotter.ipynb @@ -59,6 +59,8 @@ "import seaborn as sns\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as mpatches\n", + "from matplotlib.lines import Line2D\n", "import pandas as pd\n", "import warnings\n", "import logging" @@ -111,19 +113,41 @@ " title=None, fontsize_title=16,\n", " fontsize_rawxlabel=12, fontsize_rawylabel=12,\n", " fontsize_contrastxlabel=12, fontsize_contrastylabel=12,\n", - " fontsize_delta2label=12\n", + " fontsize_delta2label=12,\n", + " swarm_bars=True, swarm_bars_kwargs=None,\n", + " contrast_bars=True, contrast_bars_kwargs=None,\n", + " delta_text=True, delta_text_kwargs=None,\n", + " delta_dot=True, delta_dot_kwargs=None,\n", " \"\"\"\n", - " from .misc_tools import merge_two_dicts\n", + " from .misc_tools import (\n", + " get_params,\n", + " get_kwargs,\n", + " get_color_palette,\n", + " initialize_fig,\n", + " get_plot_groups,\n", + " add_counts_to_ticks,\n", + " extract_contrast_plotting_ticks,\n", + " set_xaxis_ticks_and_lims,\n", + " show_legend,\n", + " Gardner_Altman_Plot_Aesthetic_Adjustments,\n", + " Cumming_Plot_Aesthetic_Adjustments,\n", + " General_Plot_Aesthetic_Adjustments,\n", + " )\n", " from .plot_tools import (\n", - " halfviolin,\n", " get_swarm_spans,\n", " error_bar,\n", " sankeydiag,\n", " swarmplot,\n", - " )\n", - " from ._stats_tools.effsize import (\n", - " _compute_standardizers,\n", - " _compute_hedges_correction_factor,\n", + " swarm_bars_plotter,\n", + " contrast_bars_plotter,\n", + " summary_bars_plotter,\n", + " delta_text_plotter,\n", + " DeltaDotsPlotter,\n", + " slopegraph_plotter,\n", + " plot_minimeta_or_deltadelta_violins,\n", + " effect_size_curve_plotter,\n", + " grid_key_WIP,\n", + " barplotter,\n", " )\n", "\n", " warnings.filterwarnings(\n", @@ -141,499 +165,95 @@ " original_rcParams[parameter] = plt.rcParams[parameter]\n", "\n", " plt.rcParams[\"axes.grid\"] = False\n", - "\n", " ytick_color = plt.rcParams[\"ytick.color\"]\n", - " face_color = plot_kwargs[\"face_color\"]\n", - "\n", - " if plot_kwargs[\"face_color\"] is None:\n", - " face_color = \"white\"\n", - "\n", - " dabest_obj = effectsize_df.dabest_obj\n", - " plot_data = effectsize_df._plot_data\n", - " xvar = effectsize_df.xvar\n", - " yvar = effectsize_df.yvar\n", - " is_paired = effectsize_df.is_paired\n", - " delta2 = effectsize_df.delta2\n", - " mini_meta = effectsize_df.mini_meta\n", - " effect_size = effectsize_df.effect_size\n", - " proportional = effectsize_df.proportional\n", - "\n", - " all_plot_groups = dabest_obj._all_plot_groups\n", - " idx = dabest_obj.idx\n", - "\n", - " if effect_size not in [\"mean_diff\", \"delta_g\"] or not delta2:\n", - " show_delta2 = False\n", - " else:\n", - " show_delta2 = plot_kwargs[\"show_delta2\"]\n", - "\n", - " if effect_size != \"mean_diff\" or not mini_meta:\n", - " show_mini_meta = False\n", - " else:\n", - " show_mini_meta = plot_kwargs[\"show_mini_meta\"]\n", - "\n", - " if show_delta2 and show_mini_meta:\n", - " err0 = \"`show_delta2` and `show_mini_meta` cannot be True at the same time.\"\n", - " raise ValueError(err0)\n", "\n", - " # Disable Gardner-Altman plotting if any of the idxs comprise of more than\n", - " # two groups or if it is a delta-delta plot.\n", - " float_contrast = plot_kwargs[\"float_contrast\"]\n", - " effect_size_type = effectsize_df.effect_size\n", - " if len(idx) > 1 or len(idx[0]) > 2:\n", - " float_contrast = False\n", - "\n", - " if effect_size_type in [\"cliffs_delta\"]:\n", - " float_contrast = False\n", - "\n", - " if show_delta2 or show_mini_meta:\n", - " float_contrast = False\n", - "\n", - " if not is_paired:\n", - " show_pairs = False\n", - " else:\n", - " show_pairs = plot_kwargs[\"show_pairs\"]\n", - "\n", - " # Set default kwargs first, then merge with user-dictated ones.\n", - " # Swarmplot kwargs\n", - " default_swarmplot_kwargs = {\"size\": plot_kwargs[\"raw_marker_size\"]}\n", - " if plot_kwargs[\"swarmplot_kwargs\"] is None:\n", - " swarmplot_kwargs = default_swarmplot_kwargs\n", - " else:\n", - " swarmplot_kwargs = merge_two_dicts(\n", - " default_swarmplot_kwargs, plot_kwargs[\"swarmplot_kwargs\"]\n", - " )\n", - " asymmetric_side = (\n", - " \"left\" # TODO: allow users to control side for swarms of swarmplot.\n", - " )\n", - "\n", - " # Barplot kwargs\n", - " default_barplot_kwargs = {\"estimator\": np.mean, \"errorbar\": plot_kwargs[\"ci\"]}\n", - "\n", - " if plot_kwargs[\"barplot_kwargs\"] is None:\n", - " barplot_kwargs = default_barplot_kwargs\n", - " else:\n", - " barplot_kwargs = merge_two_dicts(\n", - " default_barplot_kwargs, plot_kwargs[\"barplot_kwargs\"]\n", - " )\n", + " # Extract parameters and set kwargs\n", + " (dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, \n", + " proportional, all_plot_groups, idx, show_delta2, show_mini_meta, \n", + " float_contrast, show_pairs, effect_size_type, group_summaries, err_color) = get_params(\n", + " effectsize_df=effectsize_df, \n", + " plot_kwargs=plot_kwargs\n", + " )\n", + "\n", + " (swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs, \n", + " slopegraph_kwargs, reflines_kwargs, legend_kwargs, group_summary_kwargs, redraw_axes_kwargs, \n", + " delta_dot_kwargs, delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs) = get_kwargs(\n", + " plot_kwargs=plot_kwargs, \n", + " ytick_color=ytick_color\n", + " )\n", "\n", - " # Sankey Diagram kwargs\n", - " default_sankey_kwargs = {\n", - " \"width\": 0.4,\n", - " \"align\": \"center\",\n", - " \"sankey\": True,\n", - " \"flow\": True,\n", - " \"alpha\": 0.4,\n", - " \"rightColor\": False,\n", - " \"bar_width\": 0.2,\n", - " }\n", - " if plot_kwargs[\"sankey_kwargs\"] is None:\n", - " sankey_kwargs = default_sankey_kwargs\n", - " else:\n", - " sankey_kwargs = merge_two_dicts(\n", - " default_sankey_kwargs, plot_kwargs[\"sankey_kwargs\"]\n", - " )\n", " # We also need to extract the `sankey` and `flow` from the kwargs for plotter.py\n", " # to use for varying different kinds of paired proportional plots\n", " # We also don't want to pop the parameter from the kwargs\n", - " sankey = sankey_kwargs[\"sankey\"]\n", - " flow = sankey_kwargs[\"flow\"]\n", - "\n", - " # Violinplot kwargs.\n", - " default_violinplot_kwargs = {\n", - " \"widths\": 0.5,\n", - " \"vert\": True,\n", - " \"showextrema\": False,\n", - " \"showmedians\": False,\n", - " }\n", - " if plot_kwargs[\"violinplot_kwargs\"] is None:\n", - " violinplot_kwargs = default_violinplot_kwargs\n", - " else:\n", - " violinplot_kwargs = merge_two_dicts(\n", - " default_violinplot_kwargs, plot_kwargs[\"violinplot_kwargs\"]\n", - " )\n", - "\n", - " # Slopegraph kwargs.\n", - " default_slopegraph_kwargs = {\"linewidth\": 1, \"alpha\": 0.5}\n", - " if plot_kwargs[\"slopegraph_kwargs\"] is None:\n", - " slopegraph_kwargs = default_slopegraph_kwargs\n", - " else:\n", - " slopegraph_kwargs = merge_two_dicts(\n", - " default_slopegraph_kwargs, plot_kwargs[\"slopegraph_kwargs\"]\n", - " )\n", - "\n", - " # Zero reference-line kwargs.\n", - " default_reflines_kwargs = {\n", - " \"linestyle\": \"solid\",\n", - " \"linewidth\": 0.75,\n", - " \"zorder\": 2,\n", - " \"color\": ytick_color,\n", - " }\n", - " if plot_kwargs[\"reflines_kwargs\"] is None:\n", - " reflines_kwargs = default_reflines_kwargs\n", - " else:\n", - " reflines_kwargs = merge_two_dicts(\n", - " default_reflines_kwargs, plot_kwargs[\"reflines_kwargs\"]\n", - " )\n", - "\n", - " # Legend kwargs.\n", - " default_legend_kwargs = {\"loc\": \"upper left\", \"frameon\": False}\n", - " if plot_kwargs[\"legend_kwargs\"] is None:\n", - " legend_kwargs = default_legend_kwargs\n", - " else:\n", - " legend_kwargs = merge_two_dicts(\n", - " default_legend_kwargs, plot_kwargs[\"legend_kwargs\"]\n", - " )\n", - "\n", - " ################################################### GRIDKEY WIP - extracting arguments\n", - "\n", - " gridkey_rows = plot_kwargs[\"gridkey_rows\"]\n", - " gridkey_merge_pairs = plot_kwargs[\"gridkey_merge_pairs\"]\n", - " gridkey_show_Ns = plot_kwargs[\"gridkey_show_Ns\"]\n", - " gridkey_show_es = plot_kwargs[\"gridkey_show_es\"]\n", - "\n", - " if gridkey_rows is None:\n", - " gridkey_show_Ns = False\n", - " gridkey_show_es = False\n", - "\n", - " ################################################### END GRIDKEY WIP - extracting arguments\n", - "\n", - " # Group summaries kwargs.\n", - " gs_default = {\"mean_sd\", \"median_quartiles\", None}\n", - " if plot_kwargs[\"group_summaries\"] not in gs_default:\n", - " raise ValueError(\n", - " \"group_summaries must be one of\" \" these: {}.\".format(gs_default)\n", - " )\n", - "\n", - " default_group_summary_kwargs = {\"zorder\": 3, \"lw\": 2, \"alpha\": 1}\n", - " if plot_kwargs[\"group_summary_kwargs\"] is None:\n", - " group_summary_kwargs = default_group_summary_kwargs\n", - " else:\n", - " group_summary_kwargs = merge_two_dicts(\n", - " default_group_summary_kwargs, plot_kwargs[\"group_summary_kwargs\"]\n", - " )\n", - "\n", - " # Create color palette that will be shared across subplots.\n", - " color_col = plot_kwargs[\"color_col\"]\n", - " if color_col is None:\n", - " color_groups = pd.unique(plot_data[xvar])\n", - " bootstraps_color_by_group = True\n", - " else:\n", - " if color_col not in plot_data.columns:\n", - " raise KeyError(\"``{}`` is not a column in the data.\".format(color_col))\n", - " color_groups = pd.unique(plot_data[color_col])\n", - " bootstraps_color_by_group = False\n", - " if show_pairs:\n", - " bootstraps_color_by_group = False\n", - "\n", - " # Handle the color palette.\n", - " names = color_groups\n", - " n_groups = len(color_groups)\n", - " custom_pal = plot_kwargs[\"custom_palette\"]\n", - " swarm_desat = plot_kwargs[\"swarm_desat\"]\n", - " bar_desat = plot_kwargs[\"bar_desat\"]\n", - " contrast_desat = plot_kwargs[\"halfviolin_desat\"]\n", - "\n", - " if custom_pal is None:\n", - " unsat_colors = sns.color_palette(n_colors=n_groups)\n", - " else:\n", - " if isinstance(custom_pal, dict):\n", - " groups_in_palette = {\n", - " k: v for k, v in custom_pal.items() if k in color_groups\n", - " }\n", - "\n", - " names = groups_in_palette.keys()\n", - " unsat_colors = groups_in_palette.values()\n", - "\n", - " elif isinstance(custom_pal, list):\n", - " unsat_colors = custom_pal[0:n_groups]\n", - "\n", - " elif isinstance(custom_pal, str):\n", - " # check it is in the list of matplotlib palettes.\n", - " if custom_pal in plt.colormaps():\n", - " unsat_colors = sns.color_palette(custom_pal, n_groups)\n", - " else:\n", - " err1 = \"The specified `custom_palette` {}\".format(custom_pal)\n", - " err2 = \" is not a matplotlib palette. Please check.\"\n", - " raise ValueError(err1 + err2)\n", - "\n", - " if custom_pal is None and color_col is None:\n", - " swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n", - " plot_palette_raw = dict(zip(names.categories, swarm_colors))\n", - "\n", - " bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]\n", - " plot_palette_bar = dict(zip(names.categories, bar_color))\n", - "\n", - " contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n", - " plot_palette_contrast = dict(zip(names.categories, contrast_colors))\n", - "\n", - " # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors\n", - " # default color palette will be set to \"hls\"\n", - " plot_palette_sankey = None\n", - "\n", - " else:\n", - " swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n", - " plot_palette_raw = dict(zip(names, swarm_colors))\n", - "\n", - " bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]\n", - " plot_palette_bar = dict(zip(names, bar_color))\n", - "\n", - " contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n", - " plot_palette_contrast = dict(zip(names, contrast_colors))\n", - "\n", - " plot_palette_sankey = custom_pal\n", - "\n", - " # Infer the figsize.\n", - " fig_size = plot_kwargs[\"fig_size\"]\n", - " if fig_size is None:\n", - " all_groups_count = np.sum([len(i) for i in dabest_obj.idx])\n", - " # Increase the width for delta-delta graph\n", - " if show_delta2 or show_mini_meta:\n", - " all_groups_count += 2\n", - " if is_paired and show_pairs and proportional is False:\n", - " frac = 0.75\n", - " else:\n", - " frac = 1\n", - " if float_contrast:\n", - " height_inches = 4\n", - " each_group_width_inches = 2.5 * frac\n", - " else:\n", - " height_inches = 6\n", - " each_group_width_inches = 1.5 * frac\n", - "\n", - " width_inches = each_group_width_inches * all_groups_count\n", - " fig_size = (width_inches, height_inches)\n", - "\n", - " # Initialise the figure.\n", - " init_fig_kwargs = dict(figsize=fig_size, dpi=plot_kwargs[\"dpi\"], tight_layout=True)\n", - "\n", - " width_ratios_ga = [2.5, 1]\n", - "\n", - " ###################### GRIDKEY HSPACE ALTERATION\n", - "\n", - " # Sets hspace for cummings plots if gridkey is shown.\n", - " if gridkey_rows is not None:\n", - " h_space_cummings = 0.1\n", - " else:\n", - " h_space_cummings = 0.3\n", - "\n", - " ###################### END GRIDKEY HSPACE ALTERATION\n", - "\n", - " if plot_kwargs[\"ax\"] is not None:\n", - " # New in v0.2.6.\n", - " # Use inset axes to create the estimation plot inside a single axes.\n", - " # Author: Adam L Nekimken. (PR #73)\n", - " rawdata_axes = plot_kwargs[\"ax\"]\n", - " ax_position = rawdata_axes.get_position() # [[x0, y0], [x1, y1]]\n", - "\n", - " fig = rawdata_axes.get_figure()\n", - " fig.patch.set_facecolor(face_color)\n", - "\n", - " if float_contrast:\n", - " axins = rawdata_axes.inset_axes(\n", - " [1, 0, width_ratios_ga[1] / width_ratios_ga[0], 1]\n", - " )\n", - " rawdata_axes.set_position( # [l, b, w, h]\n", - " [\n", - " ax_position.x0,\n", - " ax_position.y0,\n", - " (ax_position.x1 - ax_position.x0)\n", - " * (width_ratios_ga[0] / sum(width_ratios_ga)),\n", - " (ax_position.y1 - ax_position.y0),\n", - " ]\n", - " )\n", - "\n", - " contrast_axes = axins\n", - "\n", - " else:\n", - " axins = rawdata_axes.inset_axes([0, -1 - h_space_cummings, 1, 1])\n", - " plot_height = (ax_position.y1 - ax_position.y0) / (2 + h_space_cummings)\n", - " rawdata_axes.set_position(\n", - " [\n", - " ax_position.x0,\n", - " ax_position.y0 + (1 + h_space_cummings) * plot_height,\n", - " (ax_position.x1 - ax_position.x0),\n", - " plot_height,\n", - " ]\n", - " )\n", - "\n", - " contrast_axes = axins\n", - " rawdata_axes.contrast_axes = axins\n", - "\n", - " else:\n", - " # Here, we hardcode some figure parameters.\n", - " if float_contrast:\n", - " fig, axx = plt.subplots(\n", - " ncols=2,\n", - " gridspec_kw={\"width_ratios\": width_ratios_ga, \"wspace\": 0},\n", - " **init_fig_kwargs\n", - " )\n", - " fig.patch.set_facecolor(face_color)\n", - "\n", - " else:\n", - " fig, axx = plt.subplots(\n", - " nrows=2, gridspec_kw={\"hspace\": h_space_cummings}, **init_fig_kwargs\n", - " )\n", - " fig.patch.set_facecolor(face_color)\n", - "\n", - " # Title\n", - " title = plot_kwargs[\"title\"]\n", - " fontsize_title = plot_kwargs[\"fontsize_title\"]\n", - " if title is not None:\n", - " fig.suptitle(title, fontsize=fontsize_title)\n", - " rawdata_axes = axx[0]\n", - " contrast_axes = axx[1]\n", - " rawdata_axes.set_frame_on(False)\n", - " contrast_axes.set_frame_on(False)\n", - "\n", - " redraw_axes_kwargs = {\n", - " \"colors\": ytick_color,\n", - " \"facecolors\": ytick_color,\n", - " \"lw\": 1,\n", - " \"zorder\": 10,\n", - " \"clip_on\": False,\n", - " }\n", - "\n", - " swarm_ylim = plot_kwargs[\"swarm_ylim\"]\n", - "\n", - " if swarm_ylim is not None:\n", - " rawdata_axes.set_ylim(swarm_ylim)\n", - "\n", " one_sankey = (\n", " False if is_paired is not None else None\n", " ) # Flag to indicate if only one sankey is plotted.\n", " two_col_sankey = (\n", - " True if proportional and not one_sankey and sankey and not flow else False\n", + " True if proportional and not one_sankey and sankey_kwargs[\"sankey\"] and not sankey_kwargs[\"flow\"] else False\n", " )\n", "\n", - " if show_pairs:\n", - " # Determine temp_idx based on is_paired and proportional conditions\n", - " if is_paired == \"baseline\":\n", - " idx_pairs = [\n", - " (control, test)\n", - " for i in idx\n", - " for control, test in zip([i[0]] * (len(i) - 1), i[1:])\n", - " ]\n", - " temp_idx = idx if not proportional else idx_pairs\n", - " else:\n", - " idx_pairs = [\n", - " (control, test) for i in idx for control, test in zip(i[:-1], i[1:])\n", - " ]\n", - " temp_idx = idx if not proportional else idx_pairs\n", - "\n", - " # Determine temp_all_plot_groups based on proportional condition\n", - " plot_groups = [item for i in temp_idx for item in i]\n", - " temp_all_plot_groups = all_plot_groups if not proportional else plot_groups\n", + " # Extract Color palette\n", + " (color_col, bootstraps_color_by_group, n_groups, filled,\n", + " swarm_colors, plot_palette_raw, bar_color, \n", + " plot_palette_bar, plot_palette_contrast, plot_palette_sankey) = get_color_palette(\n", + " plot_kwargs=plot_kwargs, \n", + " plot_data=plot_data, \n", + " xvar=xvar, \n", + " show_pairs=show_pairs,\n", + " idx=idx\n", + " )\n", "\n", + " # Initialise the figure.\n", + " fig, rawdata_axes, contrast_axes, swarm_ylim = initialize_fig(\n", + " plot_kwargs=plot_kwargs, \n", + " dabest_obj=dabest_obj, \n", + " show_delta2=show_delta2, \n", + " show_mini_meta=show_mini_meta, \n", + " is_paired=is_paired, \n", + " show_pairs=show_pairs, \n", + " proportional=proportional, \n", + " float_contrast=float_contrast,\n", + " )\n", + " \n", + " # Plotting the rawdata.\n", + " if show_pairs:\n", + " temp_idx, temp_all_plot_groups = get_plot_groups(\n", + " is_paired=is_paired, \n", + " idx=idx, \n", + " proportional=proportional, \n", + " all_plot_groups=all_plot_groups\n", + " )\n", " if not proportional:\n", " # Plot the raw data as a slopegraph.\n", - " # Pivot the long (melted) data.\n", - " if color_col is None:\n", - " pivot_values = [yvar]\n", - " else:\n", - " pivot_values = [yvar, color_col]\n", - " pivoted_plot_data = pd.pivot(\n", - " data=plot_data,\n", - " index=dabest_obj.id_col,\n", - " columns=xvar,\n", - " values=pivot_values,\n", - " )\n", - " x_start = 0\n", - " for ii, current_tuple in enumerate(temp_idx):\n", - " current_pair = pivoted_plot_data.loc[\n", - " :, pd.MultiIndex.from_product([pivot_values, current_tuple])\n", - " ].dropna()\n", - " grp_count = len(current_tuple)\n", - " # Iterate through the data for the current tuple.\n", - " for ID, observation in current_pair.iterrows():\n", - " x_points = [t for t in range(x_start, x_start + grp_count)]\n", - " y_points = observation[yvar].tolist()\n", - "\n", - " if color_col is None:\n", - " slopegraph_kwargs[\"color\"] = ytick_color\n", - " else:\n", - " color_key = observation[color_col][0]\n", - " if isinstance(color_key, (str, np.int64, np.float64)):\n", - " slopegraph_kwargs[\"color\"] = plot_palette_raw[color_key]\n", - " slopegraph_kwargs[\"label\"] = color_key\n", - "\n", - " rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)\n", - "\n", - " x_start = x_start + grp_count\n", - "\n", - " ##################### DELTA PTS ON CONTRAST PLOT WIP\n", - "\n", - " contrast_show_deltas = plot_kwargs[\"contrast_show_deltas\"]\n", - "\n", - " if is_paired is None:\n", - " contrast_show_deltas = False\n", - "\n", - " if contrast_show_deltas:\n", - " delta_plot_data_temp = plot_data.copy()\n", - " delta_id_col = dabest_obj.id_col\n", - " if color_col is not None:\n", - " plot_palette_deltapts = plot_palette_raw\n", - " delta_plot_data = delta_plot_data_temp[\n", - " [xvar, yvar, delta_id_col, color_col]\n", - " ]\n", - " deltapts_args = {\n", - " \"marker\": \"^\",\n", - " \"alpha\": 0.5,\n", - " }\n", - "\n", - " else:\n", - " plot_palette_deltapts = \"k\"\n", - " delta_plot_data = delta_plot_data_temp[[xvar, yvar, delta_id_col]]\n", - " deltapts_args = {\"marker\": \"^\", \"alpha\": 0.5}\n", - "\n", - " final_deltas = pd.DataFrame()\n", - " for i in idx:\n", - " for j in i:\n", - " if i.index(j) != 0:\n", - " temp_df_exp = delta_plot_data[\n", - " delta_plot_data[xvar].str.contains(j)\n", - " ].reset_index(drop=True)\n", - " if is_paired == \"baseline\":\n", - " temp_df_cont = delta_plot_data[\n", - " delta_plot_data[xvar].str.contains(i[0])\n", - " ].reset_index(drop=True)\n", - " elif is_paired == \"sequential\":\n", - " temp_df_cont = delta_plot_data[\n", - " delta_plot_data[xvar].str.contains(\n", - " i[i.index(j) - 1]\n", - " )\n", - " ].reset_index(drop=True)\n", - " delta_df = temp_df_exp.copy()\n", - " delta_df[yvar] = temp_df_exp[yvar] - temp_df_cont[yvar]\n", - " final_deltas = pd.concat([final_deltas, delta_df])\n", - "\n", - " # swarmplot() plots swarms based on current size of ax\n", - " # Therefore, since the ax size for Gardner-Altman plot changes later on, there has to be decreased jitter\n", - " # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n", - " if float_contrast:\n", - " jitter = 0.6\n", - " else:\n", - " jitter = 1\n", - "\n", - " # Plot the raw data as a swarmplot.\n", - " deltapts_plot = swarmplot(\n", - " data=final_deltas,\n", - " x=xvar,\n", - " y=yvar,\n", - " ax=contrast_axes,\n", - " order=None,\n", - " hue=color_col,\n", - " palette=plot_palette_deltapts,\n", - " zorder=2,\n", - " size=3,\n", - " side=\"right\",\n", - " jitter=jitter,\n", - " is_drop_gutter=True,\n", - " gutter_limit=1,\n", - " **deltapts_args\n", + " slopegraph_plotter(\n", + " dabest_obj=dabest_obj, \n", + " plot_data=plot_data, \n", + " xvar=xvar, \n", + " yvar=yvar, \n", + " color_col=color_col, \n", + " plot_palette_raw=plot_palette_raw, \n", + " slopegraph_kwargs=slopegraph_kwargs, \n", + " rawdata_axes=rawdata_axes, \n", + " ytick_color=ytick_color, \n", + " temp_idx=temp_idx\n", " )\n", - " contrast_axes.legend().set_visible(False)\n", "\n", - " ##################### DELTA PTS ON CONTRAST PLOT END\n", + " # DELTA PTS ON CONTRAST PLOT WIP\n", + " show_delta_dots = plot_kwargs[\"delta_dot\"]\n", + " if show_delta_dots and is_paired is not None:\n", + " DeltaDotsPlotter(\n", + " plot_data=plot_data, \n", + " contrast_axes=contrast_axes, \n", + " delta_id_col=dabest_obj.id_col, \n", + " idx=idx, \n", + " xvar=xvar, \n", + " yvar=yvar, \n", + " is_paired=is_paired, \n", + " color_col=color_col, \n", + " float_contrast=float_contrast, \n", + " plot_palette_raw=plot_palette_raw, \n", + " delta_dot_kwargs=delta_dot_kwargs\n", + " )\n", "\n", " # Set the tick labels, because the slopegraph plotting doesn't.\n", " rawdata_axes.set_xticks(np.arange(0, len(temp_all_plot_groups)))\n", @@ -641,1013 +261,378 @@ "\n", " else:\n", " # Plot the raw data as a set of Sankey Diagrams aligned like barplot.\n", - " group_summaries = plot_kwargs[\"group_summaries\"]\n", - " if group_summaries is None:\n", - " group_summaries = \"mean_sd\"\n", - " err_color = plot_kwargs[\"err_color\"]\n", - " if err_color is None:\n", - " err_color = \"black\"\n", - "\n", - " if show_pairs:\n", - " sankey_control_group = []\n", - " sankey_test_group = []\n", - " # Design for Sankey Flow Diagram\n", - " sankey_idx = (\n", - " [\n", - " (control, test)\n", - " for i in idx\n", - " for control, test in zip(i[:], (i[1:] + (i[0],)))\n", - " ]\n", - " if flow\n", - " else temp_idx\n", - " )\n", - " for i in sankey_idx:\n", - " sankey_control_group.append(i[0])\n", - " sankey_test_group.append(i[1])\n", - "\n", - " if len(temp_all_plot_groups) == 2:\n", - " one_sankey = True\n", - " sankey_control_group.pop()\n", - " sankey_test_group.pop() # Remove the last element from two lists\n", - "\n", - " # two_col_sankey = True if proportional == True and one_sankey == False and sankey == True and flow == False else False\n", - "\n", - " # Replace the paired proportional plot with sankey diagram\n", - " sankeyplot = sankeydiag(\n", - " plot_data,\n", - " xvar=xvar,\n", - " yvar=yvar,\n", - " left_idx=sankey_control_group,\n", - " right_idx=sankey_test_group,\n", - " palette=plot_palette_sankey,\n", - " ax=rawdata_axes,\n", - " one_sankey=one_sankey,\n", - " **sankey_kwargs\n", - " )\n", - "\n", + " sankey_control_group, sankey_test_group = sankeydiag(\n", + " plot_data,\n", + " xvar=xvar,\n", + " yvar=yvar,\n", + " temp_all_plot_groups=temp_all_plot_groups,\n", + " idx=idx,\n", + " temp_idx=temp_idx,\n", + " palette=plot_palette_sankey,\n", + " ax=rawdata_axes,\n", + " **sankey_kwargs\n", + " )\n", " else:\n", " if not proportional:\n", " # Plot the raw data as a swarmplot.\n", " asymmetric_side = (\n", - " plot_kwargs[\"swarm_side\"] if plot_kwargs[\"swarm_side\"] is not None else \"right\"\n", + " plot_kwargs[\"swarm_side\"]\n", + " if plot_kwargs[\"swarm_side\"] is not None\n", + " else \"right\"\n", " ) # Default asymmetric side is right\n", "\n", " # swarmplot() plots swarms based on current size of ax\n", " # Therefore, since the ax size for mini_meta and show_delta changes later on, there has to be increased jitter\n", - " # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n", - " if show_mini_meta:\n", - " jitter = 1.25\n", - " elif show_delta2:\n", - " jitter = 1.4\n", - " else:\n", - " jitter = 1\n", - "\n", - " if color_col is None: # Determine the use of hue\n", - " rawdata_plot = swarmplot(\n", - " data=plot_data,\n", - " x=xvar,\n", - " y=yvar,\n", - " ax=rawdata_axes,\n", - " order=all_plot_groups,\n", - " hue=xvar,\n", - " palette=plot_palette_raw,\n", - " zorder=1,\n", - " side=asymmetric_side,\n", - " jitter=jitter,\n", - " is_drop_gutter=True,\n", - " gutter_limit=0.45,\n", - " **swarmplot_kwargs\n", - " )\n", + " rawdata_plot, swarm_legend_kwargs = swarmplot(\n", + " data=plot_data,\n", + " x=xvar,\n", + " y=yvar,\n", + " ax=rawdata_axes,\n", + " order=all_plot_groups,\n", + " # hue=xvar if color_col is None else color_col,\n", + " hue=color_col,\n", + " palette=plot_palette_raw,\n", + " zorder=1,\n", + " side=asymmetric_side,\n", + " jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n", + " filled=filled,\n", + " is_drop_gutter=True,\n", + " gutter_limit=0.45,\n", + " **swarmplot_kwargs\n", + " )\n", + " if color_col is None:\n", " rawdata_plot.legend().set_visible(False)\n", - " else:\n", - " rawdata_plot = swarmplot(\n", - " data=plot_data,\n", - " x=xvar,\n", - " y=yvar,\n", - " ax=rawdata_axes,\n", - " order=all_plot_groups,\n", - " hue=color_col,\n", - " palette=plot_palette_raw,\n", - " zorder=1,\n", - " side=asymmetric_side,\n", - " jitter=jitter,\n", - " is_drop_gutter=True,\n", - " gutter_limit=0.45,\n", - " **swarmplot_kwargs\n", - " )\n", - " else:\n", - " # Plot the raw data as a barplot.\n", - " bar1_df = pd.DataFrame(\n", - " {xvar: all_plot_groups, \"proportion\": np.ones(len(all_plot_groups))}\n", - " )\n", - " bar1 = sns.barplot(\n", - " data=bar1_df,\n", - " x=xvar,\n", - " y=\"proportion\",\n", - " ax=rawdata_axes,\n", - " order=all_plot_groups,\n", - " linewidth=2,\n", - " facecolor=(1, 1, 1, 0),\n", - " edgecolor=bar_color,\n", - " zorder=1,\n", - " )\n", - " bar2 = sns.barplot(\n", - " data=plot_data,\n", - " x=xvar,\n", - " y=yvar,\n", - " ax=rawdata_axes,\n", - " order=all_plot_groups,\n", - " palette=plot_palette_bar,\n", - " zorder=1,\n", - " **barplot_kwargs\n", - " )\n", - " # adjust the width of bars\n", - " bar_width = plot_kwargs[\"bar_width\"]\n", - " for bar in bar1.patches:\n", - " x = bar.get_x()\n", - " width = bar.get_width()\n", - " centre = x + width / 2.0\n", - " bar.set_x(centre - bar_width / 2.0)\n", - " bar.set_width(bar_width)\n", - "\n", - " # Plot the gapped line summaries, if this is not a Cumming plot.\n", - " # Also, we will not plot gapped lines for paired plots. For now.\n", - " group_summaries = plot_kwargs[\"group_summaries\"]\n", - " if group_summaries is None:\n", - " group_summaries = \"mean_sd\"\n", - "\n", - " if group_summaries is not None and not proportional:\n", - " # Create list to gather xspans.\n", - " xspans = []\n", - " line_colors = []\n", - " for jj, c in enumerate(rawdata_axes.collections):\n", - " try:\n", - " if asymmetric_side == \"right\":\n", - " # currently offset is hardcoded with value of -0.2\n", - " x_max_span = -0.2\n", - " else:\n", - " _, x_max, _, _ = get_swarm_spans(c)\n", - " x_max_span = x_max - jj\n", - " xspans.append(x_max_span)\n", - " except TypeError:\n", - " # we have got a None, so skip and move on.\n", - " pass\n", - "\n", - " if bootstraps_color_by_group:\n", - " line_colors.append(plot_palette_raw[all_plot_groups[jj]])\n", "\n", - " # Break the loop since hue in Seaborn adds collections to axes and it will result in index out of range\n", - " if jj >= n_groups - 1 and color_col is None:\n", - " break\n", "\n", - " if len(line_colors) != len(all_plot_groups):\n", - " line_colors = ytick_color\n", - "\n", - " error_bar(\n", - " plot_data,\n", - " x=xvar,\n", - " y=yvar,\n", - " # Hardcoded offset...\n", - " offset=xspans + np.array(plot_kwargs[\"group_summaries_offset\"]),\n", - " line_color=line_colors,\n", - " gap_width_percent=1.5,\n", - " type=group_summaries,\n", - " ax=rawdata_axes,\n", - " method=\"gapped_lines\",\n", - " **group_summary_kwargs\n", - " )\n", + " else:\n", + " # Plot the raw data as a barplot.\n", + " barplotter(\n", + " xvar=xvar, \n", + " yvar=yvar, \n", + " all_plot_groups=all_plot_groups, \n", + " rawdata_axes=rawdata_axes, \n", + " plot_data=plot_data, \n", + " bar_color=bar_color, \n", + " plot_palette_bar=plot_palette_bar, \n", + " plot_kwargs=plot_kwargs, \n", + " barplot_kwargs=barplot_kwargs\n", + " )\n", "\n", - " if group_summaries is not None and proportional:\n", - " err_color = plot_kwargs[\"err_color\"]\n", - " if err_color is None:\n", - " err_color = \"black\"\n", + " # Plot the error bars.\n", + " if group_summaries is not None:\n", + " if proportional:\n", + " group_summaries_method = \"proportional_error_bar\"\n", + " group_summaries_offset = 0\n", + " group_summaries_line_color = err_color\n", + " else:\n", + " # Create list to gather xspans.\n", + " xspans = []\n", + " line_colors = []\n", + " for jj, c in enumerate(rawdata_axes.collections):\n", + " try:\n", + " if asymmetric_side == \"right\":\n", + " # currently offset is hardcoded with value of -0.2\n", + " x_max_span = -0.2\n", + " else:\n", + " _, x_max, _, _ = get_swarm_spans(c)\n", + " x_max_span = x_max - jj\n", + " xspans.append(x_max_span)\n", + " except TypeError:\n", + " # we have got a None, so skip and move on.\n", + " pass\n", + "\n", + " if bootstraps_color_by_group:\n", + " line_colors.append(plot_palette_raw[all_plot_groups[jj]])\n", + "\n", + " # Break the loop since hue in Seaborn adds collections to axes and it will result in index out of range\n", + " if jj >= n_groups - 1 and color_col is None:\n", + " break\n", + "\n", + " if len(line_colors) != len(all_plot_groups):\n", + " line_colors = ytick_color\n", + " \n", + " # hue in swarmplot would add collections to axes which will result in len(xspans) = len(all_plot_groups) + len(unique groups in hue)\n", + " if len(xspans) > len(all_plot_groups):\n", + " xspans = xspans[:len(all_plot_groups)]\n", + "\n", + " group_summaries_method = \"gapped_lines\"\n", + " group_summaries_offset = xspans + np.array(plot_kwargs[\"group_summaries_offset\"])\n", + " group_summaries_line_color = line_colors\n", + "\n", + " # Plot\n", " error_bar(\n", " plot_data,\n", " x=xvar,\n", " y=yvar,\n", - " offset=0,\n", - " line_color=err_color,\n", + " offset=group_summaries_offset,\n", + " line_color=group_summaries_line_color,\n", " gap_width_percent=1.5,\n", " type=group_summaries,\n", " ax=rawdata_axes,\n", - " method=\"proportional_error_bar\",\n", + " method=group_summaries_method,\n", " **group_summary_kwargs\n", - " )\n", + " )\n", "\n", " # Add the counts to the rawdata axes xticks.\n", - " counts = plot_data.groupby(xvar).count()[yvar]\n", - " ticks_with_counts = []\n", - " ticks_loc = rawdata_axes.get_xticks()\n", - " rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))\n", - " for xticklab in rawdata_axes.xaxis.get_ticklabels():\n", - " t = xticklab.get_text()\n", - " if t.rfind(\"\\n\") != -1:\n", - " te = t[t.rfind(\"\\n\") + len(\"\\n\") :]\n", - " N = str(counts.loc[te])\n", - " te = t\n", - " else:\n", - " te = t\n", - " N = str(counts.loc[te])\n", - "\n", - " ticks_with_counts.append(\"{}\\nN = {}\".format(te, N))\n", - "\n", - " if plot_kwargs[\"fontsize_rawxlabel\"] is not None:\n", - " fontsize_rawxlabel = plot_kwargs[\"fontsize_rawxlabel\"]\n", - " rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel)\n", - "\n", - " # Save the handles and labels for the legend.\n", - " handles, labels = rawdata_axes.get_legend_handles_labels()\n", - " legend_labels = [l for l in labels]\n", - " legend_handles = [h for h in handles]\n", - " if bootstraps_color_by_group is False:\n", - " rawdata_axes.legend().set_visible(False)\n", + " add_counts_to_ticks(\n", + " plot_data=plot_data, \n", + " xvar=xvar, \n", + " yvar=yvar, \n", + " rawdata_axes=rawdata_axes, \n", + " plot_kwargs=plot_kwargs\n", + " )\n", "\n", - " # Enforce the xtick of rawdata_axes to be 0 and 1 after drawing only one sankey\n", + " # Enforce the xtick of rawdata_axes to be 0 and 1 after drawing only one sankey ----> Redundant code\n", " if one_sankey:\n", " rawdata_axes.set_xticks([0, 1])\n", "\n", " # Plot effect sizes and bootstraps.\n", - " # Take note of where the `control` groups are.\n", - " if is_paired == \"baseline\" and show_pairs:\n", - " if two_col_sankey:\n", - " ticks_to_skip = []\n", - " ticks_to_plot = np.arange(0, len(temp_all_plot_groups) / 2).tolist()\n", - " ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()\n", - " ticks_to_start_twocol_sankey.pop()\n", - " ticks_to_start_twocol_sankey.insert(0, 0)\n", - " else:\n", - " # ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist()\n", - " # ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist()\n", - " ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()\n", - " ticks_to_skip.insert(0, 0)\n", - " # Then obtain the ticks where we have to plot the effect sizes.\n", - " ticks_to_plot = [\n", - " t for t in range(0, len(all_plot_groups)) if t not in ticks_to_skip\n", - " ]\n", - " ticks_to_skip_contrast = np.cumsum([(len(t)) for t in idx])[:-1].tolist()\n", - " ticks_to_skip_contrast.insert(0, 0)\n", - " else:\n", - " if two_col_sankey:\n", - " ticks_to_skip = [len(sankey_control_group)]\n", - " # Then obtain the ticks where we have to plot the effect sizes.\n", - " ticks_to_plot = [\n", - " t for t in range(0, len(temp_idx)) if t not in ticks_to_skip\n", - " ]\n", - " ticks_to_skip = []\n", - " ticks_to_start_twocol_sankey = np.cumsum([len(i) - 1 for i in idx]).tolist()\n", - " ticks_to_start_twocol_sankey.pop()\n", - " ticks_to_start_twocol_sankey.insert(0, 0)\n", - " else:\n", - " ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()\n", - " ticks_to_skip.insert(0, 0)\n", - " # Then obtain the ticks where we have to plot the effect sizes.\n", - " ticks_to_plot = [\n", - " t for t in range(0, len(all_plot_groups)) if t not in ticks_to_skip\n", - " ]\n", + " plot_groups = temp_all_plot_groups if (is_paired == \"baseline\" and show_pairs and two_col_sankey) else temp_idx if (two_col_sankey) else all_plot_groups\n", + "\n", + " (ticks_to_skip, ticks_to_plot, \n", + " ticks_to_skip_contrast, ticks_to_start_twocol_sankey) = extract_contrast_plotting_ticks(\n", + " is_paired=is_paired, \n", + " show_pairs=show_pairs, \n", + " two_col_sankey=two_col_sankey, \n", + " plot_groups=plot_groups,\n", + " idx=idx,\n", + " sankey_control_group=sankey_control_group if two_col_sankey else None,\n", + " )\n", "\n", " # Plot the bootstraps, then the effect sizes and CIs.\n", " es_marker_size = plot_kwargs[\"es_marker_size\"]\n", " halfviolin_alpha = plot_kwargs[\"halfviolin_alpha\"]\n", - "\n", " ci_type = plot_kwargs[\"ci_type\"]\n", "\n", " results = effectsize_df.results\n", - " contrast_xtick_labels = []\n", - "\n", - " for j, tick in enumerate(ticks_to_plot):\n", - " current_group = results.test[j]\n", - " current_control = results.control[j]\n", - " current_bootstrap = results.bootstraps[j]\n", - " current_effsize = results.difference[j]\n", - " if ci_type == \"bca\":\n", - " current_ci_low = results.bca_low[j]\n", - " current_ci_high = results.bca_high[j]\n", - " else:\n", - " current_ci_low = results.pct_low[j]\n", - " current_ci_high = results.pct_high[j]\n", - "\n", - " # Create the violinplot.\n", - " # New in v0.2.6: drop negative infinities before plotting.\n", - " v = contrast_axes.violinplot(\n", - " current_bootstrap[~np.isinf(current_bootstrap)],\n", - " positions=[tick],\n", - " **violinplot_kwargs\n", - " )\n", - " # Turn the violinplot into half, and color it the same as the swarmplot.\n", - " # Do this only if the color column is not specified.\n", - " # Ideally, the alpha (transparency) fo the violin plot should be\n", - " # less than one so the effect size and CIs are visible.\n", - " if bootstraps_color_by_group:\n", - " fc = plot_palette_contrast[current_group]\n", - " else:\n", - " fc = \"grey\"\n", - "\n", - " halfviolin(v, fill_color=fc, alpha=halfviolin_alpha)\n", - "\n", - " # Plot the effect size.\n", - " contrast_axes.plot(\n", - " [tick],\n", - " current_effsize,\n", - " marker=\"o\",\n", - " color=ytick_color,\n", - " markersize=es_marker_size,\n", - " )\n", "\n", - " ################## SHOW ES ON CONTRAST PLOT WIP\n", - "\n", - " contrast_show_es = plot_kwargs[\"contrast_show_es\"]\n", - " es_sf = plot_kwargs[\"es_sf\"]\n", - " es_fontsize = plot_kwargs[\"es_fontsize\"]\n", - "\n", - " if gridkey_show_es:\n", - " contrast_show_es = False\n", - "\n", - " effsize_for_print = current_effsize\n", - "\n", - " printed_es = np.format_float_positional(\n", - " effsize_for_print, precision=es_sf, sign=True, trim=\"k\", min_digits=es_sf\n", - " )\n", - " if contrast_show_es:\n", - " if effsize_for_print < 0:\n", - " textoffset = 10\n", - " else:\n", - " textoffset = 15\n", - " contrast_axes.annotate(\n", - " text=printed_es,\n", - " xy=(tick, effsize_for_print),\n", - " xytext=(\n", - " -textoffset - len(printed_es) * es_fontsize / 2,\n", - " -es_fontsize / 2,\n", - " ),\n", - " textcoords=\"offset points\",\n", - " **{\"fontsize\": es_fontsize}\n", - " )\n", - "\n", - " ################## SHOW ES ON CONTRAST PLOT END\n", - "\n", - " # Plot the confidence interval.\n", - " contrast_axes.plot(\n", - " [tick, tick],\n", - " [current_ci_low, current_ci_high],\n", - " linestyle=\"-\",\n", - " color=ytick_color,\n", - " linewidth=group_summary_kwargs[\"lw\"],\n", - " )\n", - "\n", - " contrast_xtick_labels.append(\n", - " \"{}\\nminus\\n{}\".format(current_group, current_control)\n", - " )\n", + " (current_group, current_control, \n", + " current_effsize, contrast_xtick_labels) = effect_size_curve_plotter(\n", + " ticks_to_plot=ticks_to_plot, \n", + " results=results, \n", + " ci_type=ci_type, \n", + " contrast_axes=contrast_axes, \n", + " violinplot_kwargs=violinplot_kwargs, \n", + " halfviolin_alpha=halfviolin_alpha, \n", + " ytick_color=ytick_color, \n", + " es_marker_size=es_marker_size, \n", + " group_summary_kwargs=group_summary_kwargs, \n", + " bootstraps_color_by_group=bootstraps_color_by_group,\n", + " plot_palette_contrast=plot_palette_contrast,\n", + " )\n", "\n", " # Plot mini-meta violin\n", " if show_mini_meta or show_delta2:\n", - " if show_mini_meta:\n", - " mini_meta_delta = effectsize_df.mini_meta_delta\n", - " data = mini_meta_delta.bootstraps_weighted_delta\n", - " difference = mini_meta_delta.difference\n", - " if ci_type == \"bca\":\n", - " ci_low = mini_meta_delta.bca_low\n", - " ci_high = mini_meta_delta.bca_high\n", - " else:\n", - " ci_low = mini_meta_delta.pct_low\n", - " ci_high = mini_meta_delta.pct_high\n", - " else:\n", - " delta_delta = effectsize_df.delta_delta\n", - " data = delta_delta.bootstraps_delta_delta\n", - " difference = delta_delta.difference\n", - " if ci_type == \"bca\":\n", - " ci_low = delta_delta.bca_low\n", - " ci_high = delta_delta.bca_high\n", - " else:\n", - " ci_low = delta_delta.pct_low\n", - " ci_high = delta_delta.pct_high\n", - " # Create the violinplot.\n", - " # New in v0.2.6: drop negative infinities before plotting.\n", - " position = max(rawdata_axes.get_xticks()) + 2\n", - " v = contrast_axes.violinplot(\n", - " data[~np.isinf(data)], positions=[position], **violinplot_kwargs\n", - " )\n", - "\n", - " fc = \"grey\"\n", - "\n", - " halfviolin(v, fill_color=fc, alpha=halfviolin_alpha)\n", - "\n", - " # Plot the effect size.\n", - " contrast_axes.plot(\n", - " [position],\n", - " difference,\n", - " marker=\"o\",\n", - " color=ytick_color,\n", - " markersize=es_marker_size,\n", - " )\n", - " # Plot the confidence interval.\n", - " contrast_axes.plot(\n", - " [position, position],\n", - " [ci_low, ci_high],\n", - " linestyle=\"-\",\n", - " color=ytick_color,\n", - " linewidth=group_summary_kwargs[\"lw\"],\n", - " )\n", - " if show_mini_meta:\n", - " contrast_xtick_labels.extend([\"\", \"Weighted delta\"])\n", - " elif effect_size == \"delta_g\":\n", - " contrast_xtick_labels.extend([\"\", \"deltas' g\"])\n", - " else:\n", - " contrast_xtick_labels.extend([\"\", \"delta-delta\"])\n", + " contrast_xtick_labels = plot_minimeta_or_deltadelta_violins(\n", + " show_mini_meta=show_mini_meta, \n", + " effectsize_df=effectsize_df, \n", + " ci_type=ci_type, \n", + " rawdata_axes=rawdata_axes,\n", + " contrast_axes=contrast_axes, \n", + " violinplot_kwargs=violinplot_kwargs, \n", + " halfviolin_alpha=halfviolin_alpha, \n", + " ytick_color=ytick_color, \n", + " es_marker_size=es_marker_size, \n", + " group_summary_kwargs=group_summary_kwargs, \n", + " contrast_xtick_labels=contrast_xtick_labels, \n", + " effect_size=effect_size\n", + " )\n", "\n", " # Make sure the contrast_axes x-lims match the rawdata_axes xlims,\n", " # and add an extra violinplot tick for delta-delta plot.\n", - " if show_delta2 is False and show_mini_meta is False:\n", - " contrast_axes.set_xticks(rawdata_axes.get_xticks())\n", - " else:\n", - " temp = rawdata_axes.get_xticks()\n", - " temp = np.append(temp, [max(temp) + 1, max(temp) + 2])\n", - " contrast_axes.set_xticks(temp)\n", - "\n", - " if show_pairs:\n", - " max_x = contrast_axes.get_xlim()[1]\n", - " rawdata_axes.set_xlim(-0.375, max_x)\n", - "\n", - " if float_contrast:\n", - " contrast_axes.set_xlim(0.5, 1.5)\n", - " elif show_delta2 or show_mini_meta:\n", - " # Increase the xlim of raw data by 2\n", - " temp = rawdata_axes.get_xlim()\n", - " if show_pairs:\n", - " rawdata_axes.set_xlim(temp[0], temp[1] + 0.25)\n", - " else:\n", - " rawdata_axes.set_xlim(temp[0], temp[1] + 2)\n", - " contrast_axes.set_xlim(rawdata_axes.get_xlim())\n", - " else:\n", - " contrast_axes.set_xlim(rawdata_axes.get_xlim())\n", - "\n", - " # Properly label the contrast ticks.\n", - " for t in ticks_to_skip:\n", - " contrast_xtick_labels.insert(t, \"\")\n", - "\n", - " if plot_kwargs[\"fontsize_contrastxlabel\"] is not None:\n", - " fontsize_contrastxlabel = plot_kwargs[\"fontsize_contrastxlabel\"]\n", - "\n", - " contrast_axes.set_xticklabels(\n", - " contrast_xtick_labels, fontsize=fontsize_contrastxlabel\n", - " )\n", + " set_xaxis_ticks_and_lims(\n", + " show_delta2=show_delta2, \n", + " show_mini_meta=show_mini_meta, \n", + " rawdata_axes=rawdata_axes, \n", + " contrast_axes=contrast_axes, \n", + " show_pairs=show_pairs, \n", + " float_contrast=float_contrast,\n", + " ticks_to_skip=ticks_to_skip, \n", + " contrast_xtick_labels=contrast_xtick_labels, \n", + " plot_kwargs=plot_kwargs,\n", + " )\n", + " # Legend\n", + " handles, labels = rawdata_axes.get_legend_handles_labels()\n", + " legend_labels = [l for l in labels]\n", + " legend_handles = [h for h in handles]\n", "\n", " if bootstraps_color_by_group is False:\n", - " legend_labels_unique = np.unique(legend_labels)\n", - " unique_idx = np.unique(legend_labels, return_index=True)[1]\n", - " legend_handles_unique = (\n", - " pd.Series(legend_handles, dtype=\"object\").loc[unique_idx]\n", - " ).tolist()\n", - "\n", - " if len(legend_handles_unique) > 0:\n", - " if float_contrast:\n", - " axes_with_legend = contrast_axes\n", - " if show_pairs:\n", - " bta = (1.75, 1.02)\n", - " else:\n", - " bta = (1.5, 1.02)\n", - " else:\n", - " axes_with_legend = rawdata_axes\n", - " if show_pairs:\n", - " bta = (1.02, 1.0)\n", - " else:\n", - " bta = (1.0, 1.0)\n", - " leg = axes_with_legend.legend(\n", - " legend_handles_unique,\n", - " legend_labels_unique,\n", - " bbox_to_anchor=bta,\n", - " **legend_kwargs\n", + " rawdata_axes.legend().set_visible(False)\n", + " show_legend(\n", + " legend_labels=legend_labels, \n", + " legend_handles=legend_handles, \n", + " rawdata_axes=rawdata_axes, \n", + " contrast_axes=contrast_axes, \n", + " float_contrast=float_contrast, \n", + " show_pairs=show_pairs, \n", + " legend_kwargs=legend_kwargs\n", " )\n", - " if show_pairs:\n", - " for line in leg.get_lines():\n", - " line.set_linewidth(3.0)\n", "\n", + " # Add legend for swarmplot\n", + " if not show_pairs and not proportional and color_col is not None and not show_delta2:\n", + " if len(np.unique(swarm_legend_kwargs['index'])) > 1:\n", + " legend_elements = []\n", + " for color, label in zip(swarm_legend_kwargs['colors'], swarm_legend_kwargs['labels']):\n", + " legend_elements.append(Line2D([0], [0], marker='o', color='w', label=label,\n", + " markerfacecolor=color, markersize=10))\n", + " rawdata_axes.legend(handles=legend_elements, frameon=False)\n", + "\n", + " # Plot aesthetic adjustments.\n", " og_ylim_raw = rawdata_axes.get_ylim()\n", " og_xlim_raw = rawdata_axes.get_xlim()\n", "\n", " if float_contrast:\n", " # For Gardner-Altman plots only.\n", - "\n", - " # Normalize ylims and despine the floating contrast axes.\n", - " # Check that the effect size is within the swarm ylims.\n", - " if effect_size_type in [\"mean_diff\", \"cohens_d\", \"hedges_g\", \"cohens_h\"]:\n", - " control_group_summary = (\n", - " plot_data.groupby(xvar)\n", - " .mean(numeric_only=True)\n", - " .loc[current_control, yvar]\n", - " )\n", - " test_group_summary = (\n", - " plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar]\n", - " )\n", - " elif effect_size_type == \"median_diff\":\n", - " control_group_summary = (\n", - " plot_data.groupby(xvar).median().loc[current_control, yvar]\n", - " )\n", - " test_group_summary = (\n", - " plot_data.groupby(xvar).median().loc[current_group, yvar]\n", - " )\n", - "\n", - " if swarm_ylim is None:\n", - " swarm_ylim = rawdata_axes.get_ylim()\n", - "\n", - " _, contrast_xlim_max = contrast_axes.get_xlim()\n", - "\n", - " difference = float(results.difference[0])\n", - "\n", - " if effect_size_type in [\"mean_diff\", \"median_diff\"]:\n", - " # Align 0 of contrast_axes to reference group mean of rawdata_axes.\n", - " # If the effect size is positive, shift the contrast axis up.\n", - " rawdata_ylims = np.array(rawdata_axes.get_ylim())\n", - " if current_effsize > 0:\n", - " rightmin, rightmax = rawdata_ylims - current_effsize\n", - " # If the effect size is negative, shift the contrast axis down.\n", - " elif current_effsize < 0:\n", - " rightmin, rightmax = rawdata_ylims + current_effsize\n", - " else:\n", - " rightmin, rightmax = rawdata_ylims\n", - "\n", - " contrast_axes.set_ylim(rightmin, rightmax)\n", - "\n", - " og_ylim_contrast = rawdata_axes.get_ylim() - np.array(control_group_summary)\n", - "\n", - " contrast_axes.set_ylim(og_ylim_contrast)\n", - " contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max)\n", - "\n", - " elif effect_size_type in [\"cohens_d\", \"hedges_g\", \"cohens_h\"]:\n", - " if is_paired:\n", - " which_std = 1\n", - " else:\n", - " which_std = 0\n", - " temp_control = plot_data[plot_data[xvar] == current_control][yvar]\n", - " temp_test = plot_data[plot_data[xvar] == current_group][yvar]\n", - "\n", - " stds = _compute_standardizers(temp_control, temp_test)\n", - " if is_paired:\n", - " pooled_sd = stds[1]\n", - " else:\n", - " pooled_sd = stds[0]\n", - "\n", - " if effect_size_type == \"hedges_g\":\n", - " gby_count = plot_data.groupby(xvar).count()\n", - " len_control = gby_count.loc[current_control, yvar]\n", - " len_test = gby_count.loc[current_group, yvar]\n", - "\n", - " hg_correction_factor = _compute_hedges_correction_factor(\n", - " len_control, len_test\n", - " )\n", - "\n", - " ylim_scale_factor = pooled_sd / hg_correction_factor\n", - "\n", - " elif effect_size_type == \"cohens_h\":\n", - " ylim_scale_factor = (\n", - " np.mean(temp_test) - np.mean(temp_control)\n", - " ) / difference\n", - "\n", - " else:\n", - " ylim_scale_factor = pooled_sd\n", - "\n", - " scaled_ylim = (\n", - " (rawdata_axes.get_ylim() - control_group_summary) / ylim_scale_factor\n", - " ).tolist()\n", - "\n", - " contrast_axes.set_ylim(scaled_ylim)\n", - " og_ylim_contrast = scaled_ylim\n", - "\n", - " contrast_axes.set_xlim(contrast_xlim_max - 1, contrast_xlim_max)\n", - "\n", - " if one_sankey is None:\n", - " # Draw summary lines for control and test groups..\n", - " for jj, axx in enumerate([rawdata_axes, contrast_axes]):\n", - " # Draw effect size line.\n", - " if jj == 0:\n", - " ref = control_group_summary\n", - " diff = test_group_summary\n", - " effsize_line_start = 1\n", - "\n", - " elif jj == 1:\n", - " ref = 0\n", - " diff = ref + difference\n", - " effsize_line_start = contrast_xlim_max - 1.1\n", - "\n", - " xlimlow, xlimhigh = axx.get_xlim()\n", - "\n", - " # Draw reference line.\n", - " axx.hlines(\n", - " ref, # y-coordinates\n", - " 0,\n", - " xlimhigh, # x-coordinates, start and end.\n", - " **reflines_kwargs\n", - " )\n", - "\n", - " # Draw effect size line.\n", - " axx.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs)\n", - " else:\n", - " ref = 0\n", - " diff = ref + difference\n", - " effsize_line_start = contrast_xlim_max - 0.9\n", - " xlimlow, xlimhigh = contrast_axes.get_xlim()\n", - " # Draw reference line.\n", - " contrast_axes.hlines(\n", - " ref, # y-coordinates\n", - " effsize_line_start,\n", - " xlimhigh, # x-coordinates, start and end.\n", - " **reflines_kwargs\n", - " )\n", - "\n", - " # Draw effect size line.\n", - " contrast_axes.hlines(diff, effsize_line_start, xlimhigh, **reflines_kwargs)\n", - " rawdata_axes.set_xlim(og_xlim_raw) # to align the axis\n", - " # Despine appropriately.\n", - " sns.despine(ax=rawdata_axes, bottom=True)\n", - " sns.despine(ax=contrast_axes, left=True, right=False)\n", - "\n", - " # Insert break between the rawdata axes and the contrast axes\n", - " # by re-drawing the x-spine.\n", - " rawdata_axes.hlines(\n", - " og_ylim_raw[0], # yindex\n", - " rawdata_axes.get_xlim()[0],\n", - " 1.3, # xmin, xmax\n", - " **redraw_axes_kwargs\n", - " )\n", - " rawdata_axes.set_ylim(og_ylim_raw)\n", - "\n", - " contrast_axes.hlines(\n", - " contrast_axes.get_ylim()[0],\n", - " contrast_xlim_max - 0.8,\n", - " contrast_xlim_max,\n", - " **redraw_axes_kwargs\n", - " )\n", + " Gardner_Altman_Plot_Aesthetic_Adjustments(\n", + " effect_size_type=effect_size_type, \n", + " plot_data=plot_data, \n", + " xvar=xvar, \n", + " yvar=yvar, \n", + " current_control=current_control, \n", + " current_group=current_group,\n", + " rawdata_axes=rawdata_axes, \n", + " contrast_axes=contrast_axes, \n", + " results=results, \n", + " current_effsize=current_effsize, \n", + " is_paired=is_paired, \n", + " one_sankey=one_sankey,\n", + " reflines_kwargs=reflines_kwargs, \n", + " redraw_axes_kwargs=redraw_axes_kwargs, \n", + " swarm_ylim=swarm_ylim, \n", + " og_xlim_raw=og_xlim_raw,\n", + " og_ylim_raw=og_ylim_raw,\n", + " )\n", "\n", " else:\n", " # For Cumming Plots only.\n", - "\n", - " # Set custom contrast_ylim, if it was specified.\n", - " if plot_kwargs[\"contrast_ylim\"] is not None or (\n", - " plot_kwargs[\"delta2_ylim\"] is not None and show_delta2\n", - " ):\n", - " if plot_kwargs[\"contrast_ylim\"] is not None:\n", - " custom_contrast_ylim = plot_kwargs[\"contrast_ylim\"]\n", - " if plot_kwargs[\"delta2_ylim\"] is not None and show_delta2:\n", - " custom_delta2_ylim = plot_kwargs[\"delta2_ylim\"]\n", - " if custom_contrast_ylim != custom_delta2_ylim:\n", - " err1 = \"Please check if `contrast_ylim` and `delta2_ylim` are assigned\"\n", - " err2 = \"with same values.\"\n", - " raise ValueError(err1 + err2)\n", - " else:\n", - " custom_delta2_ylim = plot_kwargs[\"delta2_ylim\"]\n", - " custom_contrast_ylim = custom_delta2_ylim\n", - "\n", - " if len(custom_contrast_ylim) != 2:\n", - " err1 = \"Please check `contrast_ylim` consists of \"\n", - " err2 = \"exactly two numbers.\"\n", - " raise ValueError(err1 + err2)\n", - "\n", - " if effect_size_type == \"cliffs_delta\":\n", - " # Ensure the ylims for a cliffs_delta plot never exceed [-1, 1].\n", - " l = plot_kwargs[\"contrast_ylim\"][0]\n", - " h = plot_kwargs[\"contrast_ylim\"][1]\n", - " low = -1 if l < -1 else l\n", - " high = 1 if h > 1 else h\n", - " contrast_axes.set_ylim(low, high)\n", - " else:\n", - " contrast_axes.set_ylim(custom_contrast_ylim)\n", - "\n", - " # If 0 lies within the ylim of the contrast axes,\n", - " # draw a zero reference line.\n", - " contrast_axes_ylim = contrast_axes.get_ylim()\n", - " if contrast_axes_ylim[0] < contrast_axes_ylim[1]:\n", - " contrast_ylim_low, contrast_ylim_high = contrast_axes_ylim\n", - " else:\n", - " contrast_ylim_high, contrast_ylim_low = contrast_axes_ylim\n", - " if contrast_ylim_low < 0 < contrast_ylim_high:\n", - " contrast_axes.axhline(y=0, **reflines_kwargs)\n", - "\n", - " if is_paired == \"baseline\" and show_pairs:\n", - " if two_col_sankey:\n", - " rightend_ticks_raw = np.array([len(i) - 2 for i in idx]) + np.array(\n", - " ticks_to_start_twocol_sankey\n", - " )\n", - " elif proportional and is_paired is not None:\n", - " rightend_ticks_raw = np.array([len(i) - 1 for i in idx]) + np.array(\n", - " ticks_to_skip\n", - " )\n", - " else:\n", - " rightend_ticks_raw = np.array(\n", - " [len(i) - 1 for i in temp_idx]\n", - " ) + np.array(ticks_to_skip)\n", - " for ax in [rawdata_axes]:\n", - " sns.despine(ax=ax, bottom=True)\n", - "\n", - " ylim = ax.get_ylim()\n", - " xlim = ax.get_xlim()\n", - " redraw_axes_kwargs[\"y\"] = ylim[0]\n", - "\n", - " if two_col_sankey:\n", - " for k, start_tick in enumerate(ticks_to_start_twocol_sankey):\n", - " end_tick = rightend_ticks_raw[k]\n", - " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", - " else:\n", - " for k, start_tick in enumerate(ticks_to_skip):\n", - " end_tick = rightend_ticks_raw[k]\n", - " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", - " ax.set_ylim(ylim)\n", - " del redraw_axes_kwargs[\"y\"]\n", - "\n", - " if not proportional:\n", - " temp_length = [(len(i) - 1) for i in idx]\n", - " else:\n", - " temp_length = [(len(i) - 1) * 2 - 1 for i in idx]\n", - " if two_col_sankey:\n", - " rightend_ticks_contrast = np.array(\n", - " [len(i) - 2 for i in idx]\n", - " ) + np.array(ticks_to_start_twocol_sankey)\n", - " elif proportional and is_paired is not None:\n", - " rightend_ticks_contrast = np.array(\n", - " [len(i) - 1 for i in idx]\n", - " ) + np.array(ticks_to_skip)\n", - " else:\n", - " rightend_ticks_contrast = np.array(temp_length) + np.array(\n", - " ticks_to_skip_contrast\n", - " )\n", - " for ax in [contrast_axes]:\n", - " sns.despine(ax=ax, bottom=True)\n", - "\n", - " ylim = ax.get_ylim()\n", - " xlim = ax.get_xlim()\n", - " redraw_axes_kwargs[\"y\"] = ylim[0]\n", - "\n", - " if two_col_sankey:\n", - " for k, start_tick in enumerate(ticks_to_start_twocol_sankey):\n", - " end_tick = rightend_ticks_contrast[k]\n", - " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", - " else:\n", - " for k, start_tick in enumerate(ticks_to_skip_contrast):\n", - " end_tick = rightend_ticks_contrast[k]\n", - " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", - "\n", - " ax.set_ylim(ylim)\n", - " del redraw_axes_kwargs[\"y\"]\n", - " else:\n", - " # Compute the end of each x-axes line.\n", - " if two_col_sankey:\n", - " rightend_ticks = np.array([len(i) - 2 for i in idx]) + np.array(\n", - " ticks_to_start_twocol_sankey\n", - " )\n", - " else:\n", - " rightend_ticks = np.array([len(i) - 1 for i in idx]) + np.array(\n", - " ticks_to_skip\n", - " )\n", - "\n", - " for ax in [rawdata_axes, contrast_axes]:\n", - " sns.despine(ax=ax, bottom=True)\n", - "\n", - " ylim = ax.get_ylim()\n", - " xlim = ax.get_xlim()\n", - " redraw_axes_kwargs[\"y\"] = ylim[0]\n", - "\n", - " if two_col_sankey:\n", - " for k, start_tick in enumerate(ticks_to_start_twocol_sankey):\n", - " end_tick = rightend_ticks[k]\n", - " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", - " else:\n", - " for k, start_tick in enumerate(ticks_to_skip):\n", - " end_tick = rightend_ticks[k]\n", - " ax.hlines(xmin=start_tick, xmax=end_tick, **redraw_axes_kwargs)\n", - "\n", - " ax.set_ylim(ylim)\n", - " del redraw_axes_kwargs[\"y\"]\n", - "\n", - " if show_delta2 or show_mini_meta:\n", - " ylim = contrast_axes.get_ylim()\n", - " redraw_axes_kwargs[\"y\"] = ylim[0]\n", - " x_ticks = contrast_axes.get_xticks()\n", - " contrast_axes.hlines(xmin=x_ticks[-2], xmax=x_ticks[-1], **redraw_axes_kwargs)\n", - " del redraw_axes_kwargs[\"y\"]\n", - "\n", - " # Set raw axes y-label.\n", - " swarm_label = plot_kwargs[\"swarm_label\"]\n", - " if swarm_label is None and yvar is None:\n", - " swarm_label = \"value\"\n", - " elif swarm_label is None and yvar is not None:\n", - " swarm_label = yvar\n", - "\n", - " bar_label = plot_kwargs[\"bar_label\"]\n", - " if bar_label is None and effect_size_type != \"cohens_h\":\n", - " bar_label = \"proportion of success\"\n", - " elif bar_label is None and effect_size_type == \"cohens_h\":\n", - " bar_label = \"value\"\n", - "\n", - " # Place contrast axes y-label.\n", - " contrast_label_dict = {\n", - " \"mean_diff\": \"mean difference\",\n", - " \"median_diff\": \"median difference\",\n", - " \"cohens_d\": \"Cohen's d\",\n", - " \"hedges_g\": \"Hedges' g\",\n", - " \"cliffs_delta\": \"Cliff's delta\",\n", - " \"cohens_h\": \"Cohen's h\",\n", - " \"delta_g\": \"mean difference\",\n", - " }\n", - "\n", - " if proportional and effect_size_type != \"cohens_h\":\n", - " default_contrast_label = \"proportion difference\"\n", - " elif effect_size_type == \"delta_g\":\n", - " default_contrast_label = \"Hedges' g\"\n", - " else:\n", - " default_contrast_label = contrast_label_dict[effectsize_df.effect_size]\n", - "\n", - " if plot_kwargs[\"contrast_label\"] is None:\n", - " if is_paired:\n", - " contrast_label = \"paired\\n{}\".format(default_contrast_label)\n", - " else:\n", - " contrast_label = default_contrast_label\n", - " contrast_label = contrast_label.capitalize()\n", - " else:\n", - " contrast_label = plot_kwargs[\"contrast_label\"]\n", - "\n", - " if plot_kwargs[\"fontsize_rawylabel\"] is not None:\n", - " fontsize_rawylabel = plot_kwargs[\"fontsize_rawylabel\"]\n", - " if plot_kwargs[\"fontsize_contrastylabel\"] is not None:\n", - " fontsize_contrastylabel = plot_kwargs[\"fontsize_contrastylabel\"]\n", - " if plot_kwargs[\"fontsize_delta2label\"] is not None:\n", - " fontsize_delta2label = plot_kwargs[\"fontsize_delta2label\"]\n", - "\n", - " contrast_axes.set_ylabel(contrast_label, fontsize=fontsize_contrastylabel)\n", - " if float_contrast:\n", - " contrast_axes.yaxis.set_label_position(\"right\")\n", - "\n", - " # Set the rawdata axes labels appropriately\n", - " if not proportional:\n", - " rawdata_axes.set_ylabel(swarm_label, fontsize=fontsize_rawylabel)\n", - " else:\n", - " rawdata_axes.set_ylabel(bar_label, fontsize=fontsize_rawylabel)\n", - " rawdata_axes.set_xlabel(\"\")\n", - "\n", - " # Because we turned the axes frame off, we also need to draw back\n", - " # the y-spine for both axes.\n", - " if not float_contrast:\n", - " rawdata_axes.set_xlim(contrast_axes.get_xlim())\n", - " og_xlim_raw = rawdata_axes.get_xlim()\n", - " rawdata_axes.vlines(\n", - " og_xlim_raw[0], og_ylim_raw[0], og_ylim_raw[1], **redraw_axes_kwargs\n", - " )\n", - "\n", - " og_xlim_contrast = contrast_axes.get_xlim()\n", - "\n", - " if float_contrast:\n", - " xpos = og_xlim_contrast[1]\n", - " else:\n", - " xpos = og_xlim_contrast[0]\n", - "\n", - " og_ylim_contrast = contrast_axes.get_ylim()\n", - " contrast_axes.vlines(\n", - " xpos, og_ylim_contrast[0], og_ylim_contrast[1], **redraw_axes_kwargs\n", - " )\n", - "\n", - " if show_delta2:\n", - " if plot_kwargs[\"delta2_label\"] is not None:\n", - " delta2_label = plot_kwargs[\"delta2_label\"]\n", - " elif effect_size == \"mean_diff\":\n", - " delta2_label = \"delta - delta\"\n", - " else:\n", - " delta2_label = \"deltas' g\"\n", - " delta2_axes = contrast_axes.twinx()\n", - " delta2_axes.set_frame_on(False)\n", - " delta2_axes.set_ylabel(delta2_label, fontsize=fontsize_delta2label)\n", - " og_xlim_delta = contrast_axes.get_xlim()\n", - " og_ylim_delta = contrast_axes.get_ylim()\n", - " delta2_axes.set_ylim(og_ylim_delta)\n", - " delta2_axes.vlines(\n", - " og_xlim_delta[1], og_ylim_delta[0], og_ylim_delta[1], **redraw_axes_kwargs\n", - " )\n", - "\n", - " ################################################### GRIDKEY MAIN CODE WIP\n", - "\n", + " Cumming_Plot_Aesthetic_Adjustments(\n", + " plot_kwargs=plot_kwargs, \n", + " show_delta2=show_delta2, \n", + " effect_size_type=effect_size_type, \n", + " contrast_axes=contrast_axes, \n", + " reflines_kwargs=reflines_kwargs, \n", + " is_paired=is_paired, \n", + " show_pairs=show_pairs, \n", + " two_col_sankey=two_col_sankey, \n", + " idx=idx, \n", + " ticks_to_start_twocol_sankey=ticks_to_start_twocol_sankey,\n", + " proportional=proportional, \n", + " ticks_to_skip=ticks_to_skip, \n", + " temp_idx=temp_idx if is_paired == \"baseline\" and show_pairs else None, \n", + " rawdata_axes=rawdata_axes, \n", + " redraw_axes_kwargs=redraw_axes_kwargs,\n", + " ticks_to_skip_contrast=ticks_to_skip_contrast,\n", + " )\n", + " \n", + " # General plotting changes\n", + " General_Plot_Aesthetic_Adjustments(\n", + " show_delta2=show_delta2, \n", + " show_mini_meta=show_mini_meta, \n", + " contrast_axes=contrast_axes, \n", + " redraw_axes_kwargs=redraw_axes_kwargs, \n", + " plot_kwargs=plot_kwargs,\n", + " yvar=yvar, \n", + " effect_size_type=effect_size_type, \n", + " proportional=proportional, \n", + " effectsize_df=effectsize_df, \n", + " is_paired=is_paired, \n", + " float_contrast=float_contrast,\n", + " rawdata_axes=rawdata_axes, \n", + " og_ylim_raw=og_ylim_raw, \n", + " effect_size=effect_size,\n", + " )\n", + "\n", + " ################################################### GRIDKEY WIP\n", " # if gridkey_rows is None, skip everything here\n", + " gridkey_rows = plot_kwargs[\"gridkey_rows\"]\n", " if gridkey_rows is not None:\n", - " # Raise error if there are more than 2 items in any idx and gridkey_merge_pairs is True and is_paired is not None\n", - " if gridkey_merge_pairs and is_paired is not None:\n", - " for i in idx:\n", - " if len(i) > 2:\n", - " warnings.warn(\n", - " \"gridkey_merge_pairs=True only works if all idx in tuples have only two items. gridkey_merge_pairs has automatically been set to False\"\n", - " )\n", - " gridkey_merge_pairs = False\n", - " break\n", - " elif gridkey_merge_pairs and is_paired is None:\n", - " warnings.warn(\n", - " \"gridkey_merge_pairs=True is only applicable for paired data.\"\n", - " )\n", - " gridkey_merge_pairs = False\n", - "\n", - " # Checks for gridkey_merge_pairs and is_paired; if both are true, \"merges\" the gridkey per pair\n", - " if gridkey_merge_pairs and is_paired is not None:\n", - " groups_for_gridkey = []\n", - " for i in idx:\n", - " groups_for_gridkey.append(i[1])\n", - " else:\n", - " groups_for_gridkey = all_plot_groups\n", - "\n", - " # raise errors if gridkey_rows is not a list, or if the list is empty\n", - " if isinstance(gridkey_rows, list) is False:\n", - " raise TypeError(\"gridkey_rows must be a list.\")\n", - " elif len(gridkey_rows) == 0:\n", - " warnings.warn(\"gridkey_rows is an empty list.\")\n", + " grid_key_WIP(\n", + " is_paired=is_paired, \n", + " idx=idx, \n", + " all_plot_groups=all_plot_groups, \n", + " gridkey_rows=gridkey_rows, \n", + " rawdata_axes=rawdata_axes, \n", + " contrast_axes=contrast_axes,\n", + " plot_data=plot_data, \n", + " xvar=xvar, \n", + " yvar=yvar, \n", + " results=results, \n", + " show_delta2=show_delta2, \n", + " show_mini_meta=show_mini_meta, \n", + " float_contrast=float_contrast,\n", + " plot_kwargs=plot_kwargs,\n", + " )\n", "\n", - " # raise Warning if an item in gridkey_rows is not contained in any idx\n", - " for i in gridkey_rows:\n", - " in_idx = 0\n", - " for j in groups_for_gridkey:\n", - " if i in j:\n", - " in_idx += 1\n", - " if in_idx == 0:\n", - " if is_paired is not None:\n", - " warnings.warn(\n", - " i\n", - " + \" is not in any idx. Please check. Alternatively, merging gridkey pairs may not be suitable for your data; try passing gridkey_merge_pairs=False.\"\n", + " ################################################### Swarm & Contrast & Summary Bars & Delta text WIP\n", + " # Swarm bars WIP\n", + " swarm_bars = plot_kwargs[\"swarm_bars\"]\n", + " if swarm_bars and not proportional:\n", + " swarm_bars_plotter(\n", + " plot_data=plot_data, \n", + " xvar=xvar, \n", + " yvar=yvar, \n", + " ax=rawdata_axes, \n", + " swarm_bars_kwargs=swarm_bars_kwargs, \n", + " color_col=color_col, \n", + " plot_palette_raw=plot_palette_raw,\n", + " is_paired=is_paired\n", " )\n", - " else:\n", - " warnings.warn(i + \" is not in any idx. Please check.\")\n", - "\n", - " # Populate table: checks if idx for each column contains rowlabel name\n", - " # IF so, marks that element as present w black dot, or space if not present\n", - " table_cellcols = []\n", - " for i in gridkey_rows:\n", - " thisrow = []\n", - " for q in groups_for_gridkey:\n", - " if str(i) in q:\n", - " thisrow.append(\"\\u25CF\")\n", - " else:\n", - " thisrow.append(\"\")\n", - " table_cellcols.append(thisrow)\n", - "\n", - " # Adds a row for Ns with the Ns values\n", - " if gridkey_show_Ns:\n", - " gridkey_rows.append(\"Ns\")\n", - " list_of_Ns = []\n", - " for i in groups_for_gridkey:\n", - " list_of_Ns.append(str(counts.loc[i]))\n", - " table_cellcols.append(list_of_Ns)\n", "\n", - " # Adds a row for effectsizes with effectsize values\n", - " if gridkey_show_es:\n", - " gridkey_rows.append(\"\\u0394\")\n", - " effsize_list = []\n", - " results_list = results.test.to_list()\n", - "\n", - " # get the effect size, append + or -, 2 dec places\n", - " for i in enumerate(groups_for_gridkey):\n", - " if i[1] in results_list:\n", - " curr_esval = results.loc[results[\"test\"] == i[1]][\n", - " \"difference\"\n", - " ].iloc[0]\n", - " curr_esval_str = np.format_float_positional(\n", - " curr_esval,\n", - " precision=es_sf,\n", - " sign=True,\n", - " trim=\"k\",\n", - " min_digits=es_sf,\n", + " # Contrast bars WIP\n", + " contrast_bars = plot_kwargs[\"contrast_bars\"]\n", + " if contrast_bars:\n", + " contrast_bars_plotter(\n", + " results=results, \n", + " ax_to_plot=contrast_axes, \n", + " swarm_plot_ax=rawdata_axes,\n", + " ticks_to_plot=ticks_to_plot, \n", + " contrast_bars_kwargs=contrast_bars_kwargs, \n", + " color_col=color_col, \n", + " plot_palette_raw=plot_palette_raw, \n", + " show_mini_meta=show_mini_meta, \n", + " mini_meta_delta=effectsize_df.mini_meta_delta if show_mini_meta else None, \n", + " show_delta2=show_delta2, \n", + " delta_delta=effectsize_df.delta_delta if show_delta2 else None, \n", + " proportional=proportional, \n", + " is_paired=is_paired\n", + " )\n", + "\n", + " # Summary bars WIP\n", + " summary_bars = plot_kwargs[\"summary_bars\"]\n", + " if summary_bars is not None:\n", + " summary_bars_plotter(\n", + " summary_bars=summary_bars, \n", + " results=results, \n", + " ax_to_plot=contrast_axes, \n", + " float_contrast=float_contrast,\n", + " summary_bars_kwargs=summary_bars_kwargs, \n", + " ci_type=ci_type, \n", + " ticks_to_plot=ticks_to_plot, \n", + " color_col=color_col,\n", + " swarm_colors=swarm_colors, \n", + " proportional=proportional, \n", + " is_paired=is_paired\n", + " )\n", + " # Delta text WIP\n", + " delta_text = plot_kwargs[\"delta_text\"]\n", + " if delta_text: \n", + " delta_text_plotter(\n", + " results=results, \n", + " ax_to_plot=contrast_axes, \n", + " swarm_plot_ax=rawdata_axes, \n", + " ticks_to_plot=ticks_to_plot, \n", + " delta_text_kwargs=delta_text_kwargs, \n", + " color_col=color_col, \n", + " swarm_colors=swarm_colors, \n", + " is_paired=is_paired,\n", + " proportional=proportional, \n", + " float_contrast=float_contrast, \n", + " show_mini_meta=show_mini_meta, \n", + " mini_meta_delta=effectsize_df.mini_meta_delta if show_mini_meta else None, \n", + " show_delta2=show_delta2, \n", + " delta_delta=effectsize_df.delta_delta if show_delta2 else None\n", " )\n", - " effsize_list.append(curr_esval_str)\n", - " else:\n", - " effsize_list.append(\"-\")\n", - "\n", - " table_cellcols.append(effsize_list)\n", - "\n", - " # If Gardner-Altman plot, plot on raw data and not contrast axes\n", - " if float_contrast:\n", - " axes_ploton = rawdata_axes\n", - " else:\n", - " axes_ploton = contrast_axes\n", - "\n", - " # Account for extended x axis in case of show_delta2 or show_mini_meta\n", - " x_groups_for_width = len(groups_for_gridkey)\n", - " if show_delta2 or show_mini_meta:\n", - " x_groups_for_width += 2\n", - " gridkey_width = len(groups_for_gridkey) / x_groups_for_width\n", - "\n", - " gridkey = axes_ploton.table(\n", - " cellText=table_cellcols,\n", - " rowLabels=gridkey_rows,\n", - " cellLoc=\"center\",\n", - " bbox=[\n", - " 0,\n", - " -len(gridkey_rows) * 0.1 - 0.05,\n", - " gridkey_width,\n", - " len(gridkey_rows) * 0.1,\n", - " ],\n", - " **{\"alpha\": 0.5}\n", - " )\n", - "\n", - " # modifies row label cells\n", - " for cell in gridkey._cells:\n", - " if cell[1] == -1:\n", - " gridkey._cells[cell].visible_edges = \"open\"\n", - " gridkey._cells[cell].set_text_props(**{\"ha\": \"right\"})\n", - "\n", - " # turns off both x axes\n", - " rawdata_axes.get_xaxis().set_visible(False)\n", - " contrast_axes.get_xaxis().set_visible(False)\n", - "\n", - " ####################################################### END GRIDKEY MAIN CODE WIP\n", + " ################################################### Swarm & Contrast & Summary Bars & Delta text WIP END\n", "\n", " # Make sure no stray ticks appear!\n", " rawdata_axes.xaxis.set_ticks_position(\"bottom\")\n", @@ -1661,16 +646,9 @@ " plt.rcParams[parameter] = original_rcParams[parameter]\n", "\n", " # Return the figure.\n", - " return fig\n" + " fig.show()\n", + " return fig" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7355251f", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_01_gardner_altman_unpaired_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_01_gardner_altman_unpaired_meandiff.png index e45d3b83..d6374405 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_01_gardner_altman_unpaired_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_01_gardner_altman_unpaired_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_02_gardner_altman_unpaired_mediandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_02_gardner_altman_unpaired_mediandiff.png index de5d07ef..7dc4d313 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_02_gardner_altman_unpaired_mediandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_02_gardner_altman_unpaired_mediandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_03_gardner_altman_unpaired_hedges_g.png b/nbs/tests/mpl_image_tests/baseline_images/test_03_gardner_altman_unpaired_hedges_g.png index 80d36fcf..689b26ce 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_03_gardner_altman_unpaired_hedges_g.png and b/nbs/tests/mpl_image_tests/baseline_images/test_03_gardner_altman_unpaired_hedges_g.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_hedges_g.png b/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_hedges_g.png index 3052b159..37c2cfd6 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_hedges_g.png and b/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_hedges_g.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_meandiff.png index e86977a6..6d3f0f6d 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_04_gardner_altman_paired_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_05_cummings_two_group_unpaired_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_05_cummings_two_group_unpaired_meandiff.png index e80a42b1..47bcf864 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_05_cummings_two_group_unpaired_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_05_cummings_two_group_unpaired_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_06_cummings_two_group_paired_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_06_cummings_two_group_paired_meandiff.png index 5571d031..d8c15a9d 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_06_cummings_two_group_paired_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_06_cummings_two_group_paired_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_07_cummings_multi_group_unpaired.png b/nbs/tests/mpl_image_tests/baseline_images/test_07_cummings_multi_group_unpaired.png index 44599675..c32e34bd 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_07_cummings_multi_group_unpaired.png and b/nbs/tests/mpl_image_tests/baseline_images/test_07_cummings_multi_group_unpaired.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_08_cummings_multi_group_paired.png b/nbs/tests/mpl_image_tests/baseline_images/test_08_cummings_multi_group_paired.png index 8aeaac2b..86fff480 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_08_cummings_multi_group_paired.png and b/nbs/tests/mpl_image_tests/baseline_images/test_08_cummings_multi_group_paired.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_09_cummings_shared_control.png b/nbs/tests/mpl_image_tests/baseline_images/test_09_cummings_shared_control.png index 5c8dc16f..99a6e2aa 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_09_cummings_shared_control.png and b/nbs/tests/mpl_image_tests/baseline_images/test_09_cummings_shared_control.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_101_gardner_altman_unpaired_propdiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_101_gardner_altman_unpaired_propdiff.png index b4c3a015..79cb7092 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_101_gardner_altman_unpaired_propdiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_101_gardner_altman_unpaired_propdiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_103_cummings_two_group_unpaired_propdiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_103_cummings_two_group_unpaired_propdiff.png index bcece5c2..260147a4 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_103_cummings_two_group_unpaired_propdiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_103_cummings_two_group_unpaired_propdiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_105_cummings_multi_group_unpaired_propdiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_105_cummings_multi_group_unpaired_propdiff.png index f3990915..9b1454b2 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_105_cummings_multi_group_unpaired_propdiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_105_cummings_multi_group_unpaired_propdiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_106_cummings_shared_control_propdiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_106_cummings_shared_control_propdiff.png index b1efc8b8..83f578ce 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_106_cummings_shared_control_propdiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_106_cummings_shared_control_propdiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_107_cummings_multi_groups_propdiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_107_cummings_multi_groups_propdiff.png index e03d2a08..5df112c6 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_107_cummings_multi_groups_propdiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_107_cummings_multi_groups_propdiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_109_gardner_altman_ylabel.png b/nbs/tests/mpl_image_tests/baseline_images/test_109_gardner_altman_ylabel.png index 2a8e3fa4..4641f875 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_109_gardner_altman_ylabel.png and b/nbs/tests/mpl_image_tests/baseline_images/test_109_gardner_altman_ylabel.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_10_cummings_multi_groups.png b/nbs/tests/mpl_image_tests/baseline_images/test_10_cummings_multi_groups.png index ff99efa0..c485546e 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_10_cummings_multi_groups.png and b/nbs/tests/mpl_image_tests/baseline_images/test_10_cummings_multi_groups.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_110_change_fig_size.png b/nbs/tests/mpl_image_tests/baseline_images/test_110_change_fig_size.png index ed00258f..9d941ce2 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_110_change_fig_size.png and b/nbs/tests/mpl_image_tests/baseline_images/test_110_change_fig_size.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_111_change_palette_b.png b/nbs/tests/mpl_image_tests/baseline_images/test_111_change_palette_b.png index d43750e6..2eabd427 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_111_change_palette_b.png and b/nbs/tests/mpl_image_tests/baseline_images/test_111_change_palette_b.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_112_change_palette_c.png b/nbs/tests/mpl_image_tests/baseline_images/test_112_change_palette_c.png index 7a068a8d..2e0d86e0 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_112_change_palette_c.png and b/nbs/tests/mpl_image_tests/baseline_images/test_112_change_palette_c.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_113_desat.png b/nbs/tests/mpl_image_tests/baseline_images/test_113_desat.png index 63a3e313..94f9747c 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_113_desat.png and b/nbs/tests/mpl_image_tests/baseline_images/test_113_desat.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_114_change_ylims.png b/nbs/tests/mpl_image_tests/baseline_images/test_114_change_ylims.png index 6299d03f..46131938 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_114_change_ylims.png and b/nbs/tests/mpl_image_tests/baseline_images/test_114_change_ylims.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_115_invert_ylim.png b/nbs/tests/mpl_image_tests/baseline_images/test_115_invert_ylim.png index a16c49be..5025e477 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_115_invert_ylim.png and b/nbs/tests/mpl_image_tests/baseline_images/test_115_invert_ylim.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_116_ticker_gardner_altman.png b/nbs/tests/mpl_image_tests/baseline_images/test_116_ticker_gardner_altman.png index 2d1bb1d7..fd5a79c4 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_116_ticker_gardner_altman.png and b/nbs/tests/mpl_image_tests/baseline_images/test_116_ticker_gardner_altman.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_117_err_color.png b/nbs/tests/mpl_image_tests/baseline_images/test_117_err_color.png index 9d7b655f..01184f12 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_117_err_color.png and b/nbs/tests/mpl_image_tests/baseline_images/test_117_err_color.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_118_cummings_two_group_unpaired_meandiff_bar_width.png b/nbs/tests/mpl_image_tests/baseline_images/test_118_cummings_two_group_unpaired_meandiff_bar_width.png index 5b61946c..5d0b3938 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_118_cummings_two_group_unpaired_meandiff_bar_width.png and b/nbs/tests/mpl_image_tests/baseline_images/test_118_cummings_two_group_unpaired_meandiff_bar_width.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_119_wide_df_nan.png b/nbs/tests/mpl_image_tests/baseline_images/test_119_wide_df_nan.png index 38d99689..2880dd45 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_119_wide_df_nan.png and b/nbs/tests/mpl_image_tests/baseline_images/test_119_wide_df_nan.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_11_inset_plots.png b/nbs/tests/mpl_image_tests/baseline_images/test_11_inset_plots.png index a93e8a8d..3e0fbadd 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_11_inset_plots.png and b/nbs/tests/mpl_image_tests/baseline_images/test_11_inset_plots.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_120_long_df_nan.png b/nbs/tests/mpl_image_tests/baseline_images/test_120_long_df_nan.png index 38d99689..2880dd45 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_120_long_df_nan.png and b/nbs/tests/mpl_image_tests/baseline_images/test_120_long_df_nan.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_121_cohens_h_gardner_altman.png b/nbs/tests/mpl_image_tests/baseline_images/test_121_cohens_h_gardner_altman.png index 21a7c950..7ae94529 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_121_cohens_h_gardner_altman.png and b/nbs/tests/mpl_image_tests/baseline_images/test_121_cohens_h_gardner_altman.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_122_cohens_h_cummings.png b/nbs/tests/mpl_image_tests/baseline_images/test_122_cohens_h_cummings.png index 5c21a69c..61d934c9 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_122_cohens_h_cummings.png and b/nbs/tests/mpl_image_tests/baseline_images/test_122_cohens_h_cummings.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_123_sankey_gardner_altman.png b/nbs/tests/mpl_image_tests/baseline_images/test_123_sankey_gardner_altman.png index 698aa855..d7db7ea2 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_123_sankey_gardner_altman.png and b/nbs/tests/mpl_image_tests/baseline_images/test_123_sankey_gardner_altman.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_124_sankey_cummings.png b/nbs/tests/mpl_image_tests/baseline_images/test_124_sankey_cummings.png index d93e223d..6014c7ad 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_124_sankey_cummings.png and b/nbs/tests/mpl_image_tests/baseline_images/test_124_sankey_cummings.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_125_sankey_2paired_groups.png b/nbs/tests/mpl_image_tests/baseline_images/test_125_sankey_2paired_groups.png index 311f892c..e6cc78d7 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_125_sankey_2paired_groups.png and b/nbs/tests/mpl_image_tests/baseline_images/test_125_sankey_2paired_groups.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_126_sankey_2sequential_groups.png b/nbs/tests/mpl_image_tests/baseline_images/test_126_sankey_2sequential_groups.png index 311f892c..e6cc78d7 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_126_sankey_2sequential_groups.png and b/nbs/tests/mpl_image_tests/baseline_images/test_126_sankey_2sequential_groups.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_127_sankey_multi_group_paired.png b/nbs/tests/mpl_image_tests/baseline_images/test_127_sankey_multi_group_paired.png index 82e42603..097c8668 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_127_sankey_multi_group_paired.png and b/nbs/tests/mpl_image_tests/baseline_images/test_127_sankey_multi_group_paired.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_128_sankey_transparency.png b/nbs/tests/mpl_image_tests/baseline_images/test_128_sankey_transparency.png index 1daf9526..334b045f 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_128_sankey_transparency.png and b/nbs/tests/mpl_image_tests/baseline_images/test_128_sankey_transparency.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_129_zero_to_zero.png b/nbs/tests/mpl_image_tests/baseline_images/test_129_zero_to_zero.png index 279f9c27..279bc74b 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_129_zero_to_zero.png and b/nbs/tests/mpl_image_tests/baseline_images/test_129_zero_to_zero.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_12_gardner_altman_ylabel.png b/nbs/tests/mpl_image_tests/baseline_images/test_12_gardner_altman_ylabel.png index f18c3899..33c4ebc6 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_12_gardner_altman_ylabel.png and b/nbs/tests/mpl_image_tests/baseline_images/test_12_gardner_altman_ylabel.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_130_zero_to_one.png b/nbs/tests/mpl_image_tests/baseline_images/test_130_zero_to_one.png index 99a890cf..da88d890 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_130_zero_to_one.png and b/nbs/tests/mpl_image_tests/baseline_images/test_130_zero_to_one.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_131_one_to_zero.png b/nbs/tests/mpl_image_tests/baseline_images/test_131_one_to_zero.png index 4f6e6351..a17263f3 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_131_one_to_zero.png and b/nbs/tests/mpl_image_tests/baseline_images/test_131_one_to_zero.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_132_shared_control_sankey_off.png b/nbs/tests/mpl_image_tests/baseline_images/test_132_shared_control_sankey_off.png index 07ca4d9e..94f850d6 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_132_shared_control_sankey_off.png and b/nbs/tests/mpl_image_tests/baseline_images/test_132_shared_control_sankey_off.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_133_shared_control_flow_off.png b/nbs/tests/mpl_image_tests/baseline_images/test_133_shared_control_flow_off.png index 51fad57b..5c31d70c 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_133_shared_control_flow_off.png and b/nbs/tests/mpl_image_tests/baseline_images/test_133_shared_control_flow_off.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_134_separate_control_sankey_off.png b/nbs/tests/mpl_image_tests/baseline_images/test_134_separate_control_sankey_off.png index c3391251..c0036635 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_134_separate_control_sankey_off.png and b/nbs/tests/mpl_image_tests/baseline_images/test_134_separate_control_sankey_off.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_135_separate_control_flow_off.png b/nbs/tests/mpl_image_tests/baseline_images/test_135_separate_control_flow_off.png index 9d3c1bc5..392cdfb6 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_135_separate_control_flow_off.png and b/nbs/tests/mpl_image_tests/baseline_images/test_135_separate_control_flow_off.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_136_style_sheets.png b/nbs/tests/mpl_image_tests/baseline_images/test_136_style_sheets.png index 297e1b43..93475edd 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_136_style_sheets.png and b/nbs/tests/mpl_image_tests/baseline_images/test_136_style_sheets.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_13_multi_2group_color.png b/nbs/tests/mpl_image_tests/baseline_images/test_13_multi_2group_color.png index 12a110a8..f146a6a8 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_13_multi_2group_color.png and b/nbs/tests/mpl_image_tests/baseline_images/test_13_multi_2group_color.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_14_gardner_altman_paired_color.png b/nbs/tests/mpl_image_tests/baseline_images/test_14_gardner_altman_paired_color.png index 4b293951..2b9aea2f 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_14_gardner_altman_paired_color.png and b/nbs/tests/mpl_image_tests/baseline_images/test_14_gardner_altman_paired_color.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_15_change_palette_a.png b/nbs/tests/mpl_image_tests/baseline_images/test_15_change_palette_a.png index 46533b5f..1a1320e5 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_15_change_palette_a.png and b/nbs/tests/mpl_image_tests/baseline_images/test_15_change_palette_a.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_16_change_palette_b.png b/nbs/tests/mpl_image_tests/baseline_images/test_16_change_palette_b.png index 7a1755e1..c1e4ae2e 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_16_change_palette_b.png and b/nbs/tests/mpl_image_tests/baseline_images/test_16_change_palette_b.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_17_change_palette_c.png b/nbs/tests/mpl_image_tests/baseline_images/test_17_change_palette_c.png index 3d91180c..f04b0dc9 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_17_change_palette_c.png and b/nbs/tests/mpl_image_tests/baseline_images/test_17_change_palette_c.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_18_desat.png b/nbs/tests/mpl_image_tests/baseline_images/test_18_desat.png index 67aa7c9d..53095da8 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_18_desat.png and b/nbs/tests/mpl_image_tests/baseline_images/test_18_desat.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_19_dot_sizes.png b/nbs/tests/mpl_image_tests/baseline_images/test_19_dot_sizes.png index 40cfeabe..eed9a9b3 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_19_dot_sizes.png and b/nbs/tests/mpl_image_tests/baseline_images/test_19_dot_sizes.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_201_forest_plot_no_colorpalette.png b/nbs/tests/mpl_image_tests/baseline_images/test_201_forest_plot_no_colorpalette.png index 0926bddf..35fef1a4 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_201_forest_plot_no_colorpalette.png and b/nbs/tests/mpl_image_tests/baseline_images/test_201_forest_plot_no_colorpalette.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_202_forest_plot_with_colorpalette.png b/nbs/tests/mpl_image_tests/baseline_images/test_202_forest_plot_with_colorpalette.png index 12c37b1c..41f0a339 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_202_forest_plot_with_colorpalette.png and b/nbs/tests/mpl_image_tests/baseline_images/test_202_forest_plot_with_colorpalette.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_203_horizontal_forest_plot_no_colorpalette.png b/nbs/tests/mpl_image_tests/baseline_images/test_203_horizontal_forest_plot_no_colorpalette.png index 88ed2da6..6f7c946f 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_203_horizontal_forest_plot_no_colorpalette.png and b/nbs/tests/mpl_image_tests/baseline_images/test_203_horizontal_forest_plot_no_colorpalette.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_204_horizontal_forest_plot_with_colorpalette.png b/nbs/tests/mpl_image_tests/baseline_images/test_204_horizontal_forest_plot_with_colorpalette.png index b55d9f25..0456368e 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_204_horizontal_forest_plot_with_colorpalette.png and b/nbs/tests/mpl_image_tests/baseline_images/test_204_horizontal_forest_plot_with_colorpalette.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_205_forest_mini_meta_horizontal.png b/nbs/tests/mpl_image_tests/baseline_images/test_205_forest_mini_meta_horizontal.png index d429c7ea..d8d273ae 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_205_forest_mini_meta_horizontal.png and b/nbs/tests/mpl_image_tests/baseline_images/test_205_forest_mini_meta_horizontal.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_206_forest_mini_meta.png b/nbs/tests/mpl_image_tests/baseline_images/test_206_forest_mini_meta.png index ad1dc77c..68984662 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_206_forest_mini_meta.png and b/nbs/tests/mpl_image_tests/baseline_images/test_206_forest_mini_meta.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_207_gardner_altman_meandiff_empty_circle.png b/nbs/tests/mpl_image_tests/baseline_images/test_207_gardner_altman_meandiff_empty_circle.png new file mode 100644 index 00000000..f8f4d10d Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_207_gardner_altman_meandiff_empty_circle.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_208_cummings_two_group_unpaired_meandiff_empty_circle.png b/nbs/tests/mpl_image_tests/baseline_images/test_208_cummings_two_group_unpaired_meandiff_empty_circle.png new file mode 100644 index 00000000..f908be45 Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_208_cummings_two_group_unpaired_meandiff_empty_circle.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_209_cummings_shared_control_meandiff_empty_circle.png b/nbs/tests/mpl_image_tests/baseline_images/test_209_cummings_shared_control_meandiff_empty_circle.png new file mode 100644 index 00000000..ec1d2919 Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_209_cummings_shared_control_meandiff_empty_circle.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_20_change_ylims.png b/nbs/tests/mpl_image_tests/baseline_images/test_20_change_ylims.png index 879873a6..5f5b42c3 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_20_change_ylims.png and b/nbs/tests/mpl_image_tests/baseline_images/test_20_change_ylims.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_210_cummings_multi_groups_meandiff_empty_circle.png b/nbs/tests/mpl_image_tests/baseline_images/test_210_cummings_multi_groups_meandiff_empty_circle.png new file mode 100644 index 00000000..d682eb97 Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_210_cummings_multi_groups_meandiff_empty_circle.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_211_cummings_multi_2_group_meandiff_empty_circle.png b/nbs/tests/mpl_image_tests/baseline_images/test_211_cummings_multi_2_group_meandiff_empty_circle.png new file mode 100644 index 00000000..0b324642 Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_211_cummings_multi_2_group_meandiff_empty_circle.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_212_cummings_unpaired_delta_delta_meandiff_empty_circle.png b/nbs/tests/mpl_image_tests/baseline_images/test_212_cummings_unpaired_delta_delta_meandiff_empty_circle.png new file mode 100644 index 00000000..20cfda17 Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_212_cummings_unpaired_delta_delta_meandiff_empty_circle.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_213_cummings_unpaired_mini_meta_meandiff_empty_circle.png b/nbs/tests/mpl_image_tests/baseline_images/test_213_cummings_unpaired_mini_meta_meandiff_empty_circle.png new file mode 100644 index 00000000..8bb6a91e Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_213_cummings_unpaired_mini_meta_meandiff_empty_circle.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_21_invert_ylim.png b/nbs/tests/mpl_image_tests/baseline_images/test_21_invert_ylim.png index 26b7db6d..6a91213b 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_21_invert_ylim.png and b/nbs/tests/mpl_image_tests/baseline_images/test_21_invert_ylim.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_22_ticker_gardner_altman.png b/nbs/tests/mpl_image_tests/baseline_images/test_22_ticker_gardner_altman.png index ff074e1d..19b28ad0 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_22_ticker_gardner_altman.png and b/nbs/tests/mpl_image_tests/baseline_images/test_22_ticker_gardner_altman.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_23_ticker_cumming.png b/nbs/tests/mpl_image_tests/baseline_images/test_23_ticker_cumming.png index 9b5604a1..a74d8fb2 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_23_ticker_cumming.png and b/nbs/tests/mpl_image_tests/baseline_images/test_23_ticker_cumming.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_24_wide_df_nan.png b/nbs/tests/mpl_image_tests/baseline_images/test_24_wide_df_nan.png index f7b0739f..577fa498 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_24_wide_df_nan.png and b/nbs/tests/mpl_image_tests/baseline_images/test_24_wide_df_nan.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_25_long_df_nan.png b/nbs/tests/mpl_image_tests/baseline_images/test_25_long_df_nan.png index f7b0739f..577fa498 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_25_long_df_nan.png and b/nbs/tests/mpl_image_tests/baseline_images/test_25_long_df_nan.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_26_slopegraph_kwargs.png b/nbs/tests/mpl_image_tests/baseline_images/test_26_slopegraph_kwargs.png index 4744c6da..87359f5f 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_26_slopegraph_kwargs.png and b/nbs/tests/mpl_image_tests/baseline_images/test_26_slopegraph_kwargs.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_27_gardner_altman_reflines_kwargs.png b/nbs/tests/mpl_image_tests/baseline_images/test_27_gardner_altman_reflines_kwargs.png index 237637f3..73336ca7 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_27_gardner_altman_reflines_kwargs.png and b/nbs/tests/mpl_image_tests/baseline_images/test_27_gardner_altman_reflines_kwargs.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_28_unpaired_cumming_reflines_kwargs.png b/nbs/tests/mpl_image_tests/baseline_images/test_28_unpaired_cumming_reflines_kwargs.png index 6697e15b..61c7507a 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_28_unpaired_cumming_reflines_kwargs.png and b/nbs/tests/mpl_image_tests/baseline_images/test_28_unpaired_cumming_reflines_kwargs.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_29_paired_cumming_slopegraph_reflines_kwargs.png b/nbs/tests/mpl_image_tests/baseline_images/test_29_paired_cumming_slopegraph_reflines_kwargs.png index ae1a9787..ea4b426d 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_29_paired_cumming_slopegraph_reflines_kwargs.png and b/nbs/tests/mpl_image_tests/baseline_images/test_29_paired_cumming_slopegraph_reflines_kwargs.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_30_sequential_cumming_slopegraph.png b/nbs/tests/mpl_image_tests/baseline_images/test_30_sequential_cumming_slopegraph.png index 1cb35bc0..6975e1c5 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_30_sequential_cumming_slopegraph.png and b/nbs/tests/mpl_image_tests/baseline_images/test_30_sequential_cumming_slopegraph.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_31_baseline_cumming_slopegraph.png b/nbs/tests/mpl_image_tests/baseline_images/test_31_baseline_cumming_slopegraph.png index abe3580d..893ca46e 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_31_baseline_cumming_slopegraph.png and b/nbs/tests/mpl_image_tests/baseline_images/test_31_baseline_cumming_slopegraph.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_47_cummings_unpaired_delta_delta_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_47_cummings_unpaired_delta_delta_meandiff.png index 2001ce6f..d2cf430d 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_47_cummings_unpaired_delta_delta_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_47_cummings_unpaired_delta_delta_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_48_cummings_sequential_delta_delta_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_48_cummings_sequential_delta_delta_meandiff.png index 53376f23..824dff19 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_48_cummings_sequential_delta_delta_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_48_cummings_sequential_delta_delta_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_49_cummings_baseline_delta_delta_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_49_cummings_baseline_delta_delta_meandiff.png index 53376f23..824dff19 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_49_cummings_baseline_delta_delta_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_49_cummings_baseline_delta_delta_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_50_delta_plot_ylabel.png b/nbs/tests/mpl_image_tests/baseline_images/test_50_delta_plot_ylabel.png index d94de0a3..6ef6fc16 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_50_delta_plot_ylabel.png and b/nbs/tests/mpl_image_tests/baseline_images/test_50_delta_plot_ylabel.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_51_delta_plot_change_palette_a.png b/nbs/tests/mpl_image_tests/baseline_images/test_51_delta_plot_change_palette_a.png index 97b9e645..ea4adfc4 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_51_delta_plot_change_palette_a.png and b/nbs/tests/mpl_image_tests/baseline_images/test_51_delta_plot_change_palette_a.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_52_delta_specified.png b/nbs/tests/mpl_image_tests/baseline_images/test_52_delta_specified.png index bc07a8bb..ed3432ef 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_52_delta_specified.png and b/nbs/tests/mpl_image_tests/baseline_images/test_52_delta_specified.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_53_delta_change_ylims.png b/nbs/tests/mpl_image_tests/baseline_images/test_53_delta_change_ylims.png index 625d2dd4..794a0548 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_53_delta_change_ylims.png and b/nbs/tests/mpl_image_tests/baseline_images/test_53_delta_change_ylims.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_54_delta_invert_ylim.png b/nbs/tests/mpl_image_tests/baseline_images/test_54_delta_invert_ylim.png index 818e2125..ef63c04d 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_54_delta_invert_ylim.png and b/nbs/tests/mpl_image_tests/baseline_images/test_54_delta_invert_ylim.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_55_delta_median_diff.png b/nbs/tests/mpl_image_tests/baseline_images/test_55_delta_median_diff.png index e339eaac..2bf942a6 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_55_delta_median_diff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_55_delta_median_diff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_56_delta_cohens_d.png b/nbs/tests/mpl_image_tests/baseline_images/test_56_delta_cohens_d.png index f70b5423..3cb8a185 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_56_delta_cohens_d.png and b/nbs/tests/mpl_image_tests/baseline_images/test_56_delta_cohens_d.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_57_delta_show_delta2.png b/nbs/tests/mpl_image_tests/baseline_images/test_57_delta_show_delta2.png index 4386758e..d04fa036 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_57_delta_show_delta2.png and b/nbs/tests/mpl_image_tests/baseline_images/test_57_delta_show_delta2.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_58_delta_axes_invert_ylim.png b/nbs/tests/mpl_image_tests/baseline_images/test_58_delta_axes_invert_ylim.png index 238e4827..8bd09757 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_58_delta_axes_invert_ylim.png and b/nbs/tests/mpl_image_tests/baseline_images/test_58_delta_axes_invert_ylim.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_59_delta_axes_invert_ylim_not_showing_delta2.png b/nbs/tests/mpl_image_tests/baseline_images/test_59_delta_axes_invert_ylim_not_showing_delta2.png index 4386758e..d04fa036 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_59_delta_axes_invert_ylim_not_showing_delta2.png and b/nbs/tests/mpl_image_tests/baseline_images/test_59_delta_axes_invert_ylim_not_showing_delta2.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_60_cummings_unpaired_mini_meta_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_60_cummings_unpaired_mini_meta_meandiff.png index 05675a6f..cb2356ad 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_60_cummings_unpaired_mini_meta_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_60_cummings_unpaired_mini_meta_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_61_cummings_sequential_mini_meta_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_61_cummings_sequential_mini_meta_meandiff.png index 9fde7c9e..da3a9580 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_61_cummings_sequential_mini_meta_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_61_cummings_sequential_mini_meta_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_62_cummings_baseline_mini_meta_meandiff.png b/nbs/tests/mpl_image_tests/baseline_images/test_62_cummings_baseline_mini_meta_meandiff.png index 9fde7c9e..da3a9580 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_62_cummings_baseline_mini_meta_meandiff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_62_cummings_baseline_mini_meta_meandiff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_63_mini_meta_plot_ylabel.png b/nbs/tests/mpl_image_tests/baseline_images/test_63_mini_meta_plot_ylabel.png index b86ff496..09edd730 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_63_mini_meta_plot_ylabel.png and b/nbs/tests/mpl_image_tests/baseline_images/test_63_mini_meta_plot_ylabel.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_64_mini_meta_plot_change_palette_a.png b/nbs/tests/mpl_image_tests/baseline_images/test_64_mini_meta_plot_change_palette_a.png index e8ccb3f0..b9975211 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_64_mini_meta_plot_change_palette_a.png and b/nbs/tests/mpl_image_tests/baseline_images/test_64_mini_meta_plot_change_palette_a.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_65_mini_meta_dot_sizes.png b/nbs/tests/mpl_image_tests/baseline_images/test_65_mini_meta_dot_sizes.png index 21dfc2ea..a713453e 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_65_mini_meta_dot_sizes.png and b/nbs/tests/mpl_image_tests/baseline_images/test_65_mini_meta_dot_sizes.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_66_mini_meta_change_ylims.png b/nbs/tests/mpl_image_tests/baseline_images/test_66_mini_meta_change_ylims.png index 4189ad49..7fba8be7 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_66_mini_meta_change_ylims.png and b/nbs/tests/mpl_image_tests/baseline_images/test_66_mini_meta_change_ylims.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_67_mini_meta_invert_ylim.png b/nbs/tests/mpl_image_tests/baseline_images/test_67_mini_meta_invert_ylim.png index 9e1992fe..9ff26211 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_67_mini_meta_invert_ylim.png and b/nbs/tests/mpl_image_tests/baseline_images/test_67_mini_meta_invert_ylim.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_68_mini_meta_median_diff.png b/nbs/tests/mpl_image_tests/baseline_images/test_68_mini_meta_median_diff.png index 6a42eb52..d1bc094b 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_68_mini_meta_median_diff.png and b/nbs/tests/mpl_image_tests/baseline_images/test_68_mini_meta_median_diff.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_69_mini_meta_cohens_d.png b/nbs/tests/mpl_image_tests/baseline_images/test_69_mini_meta_cohens_d.png index e68c2983..7eeb64e6 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_69_mini_meta_cohens_d.png and b/nbs/tests/mpl_image_tests/baseline_images/test_69_mini_meta_cohens_d.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_70_mini_meta_not_show.png b/nbs/tests/mpl_image_tests/baseline_images/test_70_mini_meta_not_show.png index bc0bf7f4..49f75004 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_70_mini_meta_not_show.png and b/nbs/tests/mpl_image_tests/baseline_images/test_70_mini_meta_not_show.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_71_unpaired_delta_g.png b/nbs/tests/mpl_image_tests/baseline_images/test_71_unpaired_delta_g.png index 7823d235..b9d3dda7 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_71_unpaired_delta_g.png and b/nbs/tests/mpl_image_tests/baseline_images/test_71_unpaired_delta_g.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_72_sequential_delta_g.png b/nbs/tests/mpl_image_tests/baseline_images/test_72_sequential_delta_g.png index 53376f23..824dff19 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_72_sequential_delta_g.png and b/nbs/tests/mpl_image_tests/baseline_images/test_72_sequential_delta_g.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_73_baseline_delta_g.png b/nbs/tests/mpl_image_tests/baseline_images/test_73_baseline_delta_g.png index 53376f23..824dff19 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_73_baseline_delta_g.png and b/nbs/tests/mpl_image_tests/baseline_images/test_73_baseline_delta_g.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_99_style_sheets.png b/nbs/tests/mpl_image_tests/baseline_images/test_99_style_sheets.png index dd9a202a..112f96aa 100644 Binary files a/nbs/tests/mpl_image_tests/baseline_images/test_99_style_sheets.png and b/nbs/tests/mpl_image_tests/baseline_images/test_99_style_sheets.png differ diff --git a/nbs/tests/mpl_image_tests/test_plot_aesthetics.py b/nbs/tests/mpl_image_tests/test_plot_aesthetics.py new file mode 100644 index 00000000..3c74f069 --- /dev/null +++ b/nbs/tests/mpl_image_tests/test_plot_aesthetics.py @@ -0,0 +1,175 @@ +import pytest +import numpy as np +import pandas as pd +from scipy.stats import norm + + +import matplotlib as mpl + +mpl.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.ticker as Ticker + +from dabest._api import load + + +def create_demo_dataset(seed=9999, N=20): + import numpy as np + import pandas as pd + from scipy.stats import norm # Used in generation of populations. + + np.random.seed(9999) # Fix the seed so the results are replicable. + # pop_size = 10000 # Size of each population. + + # Create samples + c1 = norm.rvs(loc=3, scale=0.4, size=N) + c2 = norm.rvs(loc=3.5, scale=0.75, size=N) + c3 = norm.rvs(loc=3.25, scale=0.4, size=N) + + t1 = norm.rvs(loc=3.5, scale=0.5, size=N) + t2 = norm.rvs(loc=2.5, scale=0.6, size=N) + t3 = norm.rvs(loc=3, scale=0.75, size=N) + t4 = norm.rvs(loc=3.5, scale=0.75, size=N) + t5 = norm.rvs(loc=3.25, scale=0.4, size=N) + t6 = norm.rvs(loc=3.25, scale=0.4, size=N) + + # Add a `gender` column for coloring the data. + females = np.repeat("Female", N / 2).tolist() + males = np.repeat("Male", N / 2).tolist() + gender = females + males + + # Add an `id` column for paired data plotting. + id_col = pd.Series(range(1, N + 1)) + + # Combine samples and gender into a DataFrame. + df = pd.DataFrame( + { + "Control 1": c1, + "Test 1": t1, + "Control 2": c2, + "Test 2": t2, + "Control 3": c3, + "Test 3": t3, + "Test 4": t4, + "Test 5": t5, + "Test 6": t6, + "Gender": gender, + "ID": id_col, + } + ) + + return df + + +def create_demo_dataset_delta(seed=9999, N=20): + + import numpy as np + import pandas as pd + from scipy.stats import norm # Used in generation of populations. + + np.random.seed(seed) # Fix the seed so the results are replicable. + # pop_size = 10000 # Size of each population. + + from scipy.stats import norm # Used in generation of populations. + + # Create samples + y = norm.rvs(loc=3, scale=0.4, size=N * 4) + y[N : 2 * N] = y[N : 2 * N] + 1 + y[2 * N : 3 * N] = y[2 * N : 3 * N] - 0.5 + + # Add drug column + t1 = np.repeat("Placebo", N * 2).tolist() + t2 = np.repeat("Drug", N * 2).tolist() + treatment = t1 + t2 + + # Add a `rep` column as the first variable for the 2 replicates of experiments done + rep = [] + for i in range(N * 2): + rep.append("Rep1") + rep.append("Rep2") + + # Add a `genotype` column as the second variable + wt = np.repeat("W", N).tolist() + mt = np.repeat("M", N).tolist() + wt2 = np.repeat("W", N).tolist() + mt2 = np.repeat("M", N).tolist() + + genotype = wt + mt + wt2 + mt2 + + # Add an `id` column for paired data plotting. + id = list(range(0, N * 2)) + id_col = id + id + + # Combine all columns into a DataFrame. + df = pd.DataFrame( + {"ID": id_col, "Rep": rep, "Genotype": genotype, "Treatment": treatment, "Y": y} + ) + return df + + +df = create_demo_dataset() +df_delta = create_demo_dataset_delta() + +two_groups_unpaired = load(df, idx=("Control 1", "Test 1")) + +multi_2group = load( + df, + idx=( + ( + "Control 1", + "Test 1", + ), + ("Control 2", "Test 2"), + ), +) + +shared_control = load( + df, idx=("Control 1", "Test 1", "Test 2", "Test 3", "Test 4", "Test 5", "Test 6") +) + +multi_groups = load( + df, + idx=( + ( + "Control 1", + "Test 1", + ), + ("Control 2", "Test 2", "Test 3"), + ("Control 3", "Test 4", "Test 5", "Test 6"), + ), +) + +unpaired_delta_delta = load( + data=df_delta, x=["Genotype", "Genotype"], y="Y", delta2=True, experiment="Treatment" +) + +unpaired_mini_meta = load(df, idx=(("Control 1", "Test 1"), ("Control 2", "Test 2"), ("Control 3", "Test 3")), + mini_meta=True) + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_207_gardner_altman_meandiff_empty_circle(): + return two_groups_unpaired.mean_diff.plot(empty_circle=True); + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_208_cummings_two_group_unpaired_meandiff_empty_circle(): + return two_groups_unpaired.mean_diff.plot(empty_circle=True, float_contrast=False); + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_209_cummings_shared_control_meandiff_empty_circle(): + return shared_control.mean_diff.plot(empty_circle=True); + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_210_cummings_multi_groups_meandiff_empty_circle(): + return multi_groups.mean_diff.plot(empty_circle=True); + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_211_cummings_multi_2_group_meandiff_empty_circle(): + return multi_2group.mean_diff.plot(empty_circle=True); + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_212_cummings_unpaired_delta_delta_meandiff_empty_circle(): + return unpaired_delta_delta.mean_diff.plot(empty_circle=True); + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_213_cummings_unpaired_mini_meta_meandiff_empty_circle(): + return unpaired_mini_meta.mean_diff.plot(empty_circle=True); diff --git a/nbs/tests/test_plot_tools.py b/nbs/tests/test_plot_tools.py index b47dba7f..fb07f9c9 100644 --- a/nbs/tests/test_plot_tools.py +++ b/nbs/tests/test_plot_tools.py @@ -94,6 +94,7 @@ def test_check_data_matches_labels(): ("jitter", None, "`jitter` must be a scalar or float.", ValueError), ("is_drop_gutter", None, "`is_drop_gutter` must be a boolean.", ValueError), ("gutter_limit", None, "`gutter_limit` must be a scalar or float.", ValueError), + ("filled", 1, "`filled` must be a boolean, list or tuple.", ValueError), # More thorough input validation checks ("x", "a", "a is not a column in `data`.", IndexError), @@ -104,7 +105,9 @@ def test_check_data_matches_labels(): ("palette", {"Control 1": " "}, "The color mapping for Control 1 in `palette` is an empty string. It must contain a color name.", ValueError), ("palette", {"Control 3": "black"}, "Control 3 in `palette` is not in the 'group' column of `data`.", IndexError), # TODO: to add palette validation testing for when color_col is hue - ("side", "top", "Invalid `side`. Must be one of 'center', 'right', or 'left'.", ValueError) + ("side", "top", "Invalid `side`. Must be one of 'center', 'right', or 'left'.", ValueError), + ("filled", [True, "a"], "All values in `filled` must be a boolean.", ValueError), + ("filled", [True], "There are 2 unique values in `x` column in `data` but `filled` has a length of 1.", ValueError), ]) def test_swarmplot_input_error_handling(param_name, param_value, error_msg, error_type): with pytest.raises(error_type) as excinfo: @@ -120,6 +123,7 @@ def test_swarmplot_input_error_handling(param_name, param_value, error_msg, erro size=5 if param_name != "size" else param_value, side="center" if param_name != "side" else param_value, jitter=1 if param_name != "jitter" else param_value, + filled=True if param_name != "filled" else param_value, is_drop_gutter=True if param_name != "is_drop_gutter" else param_value, gutter_limit=0.5 if param_name != "gutter_limit" else param_value, ) diff --git a/test.py b/test.py new file mode 100644 index 00000000..105136f1 --- /dev/null +++ b/test.py @@ -0,0 +1,88 @@ +import numpy as np +from scipy.stats import norm +import pandas as pd +import matplotlib as mpl +import os +from pathlib import Path + +import matplotlib.ticker as Ticker +import matplotlib.pyplot as plt + +from dabest._api import load + +import dabest + +columns = [1, 2.0] +columns_str = ["1", "2.0"] +# create a test database +N = 100 +df = pd.DataFrame(np.vstack([np.random.normal(loc=i, size=(N,)) for i in range(len(columns))]).T, columns=columns_str) +females = np.repeat("Female", N / 2).tolist() +males = np.repeat("Male", N / 2).tolist() +df['gender'] = females + males + +# Add an `id` column for paired data plotting. +df['ID'] = pd.Series(range(1, N + 1)) + + +db = dabest.load(data=df, idx=columns_str, paired="baseline", id_col="ID") +print(db.mean_diff) +db.mean_diff.plot(); + +# def create_demo_dataset(seed=9999, N=20): +# import numpy as np +# import pandas as pd +# from scipy.stats import norm # Used in generation of populations. + +# np.random.seed(9999) # Fix the seed so the results are replicable. +# # pop_size = 10000 # Size of each population. + +# # Create samples +# c1 = norm.rvs(loc=3, scale=0.4, size=N) +# c2 = norm.rvs(loc=3.5, scale=0.75, size=N) +# c3 = norm.rvs(loc=3.25, scale=0.4, size=N) + +# t1 = norm.rvs(loc=3.5, scale=0.5, size=N) +# t2 = norm.rvs(loc=2.5, scale=0.6, size=N) +# t3 = norm.rvs(loc=3, scale=0.75, size=N) +# t4 = norm.rvs(loc=3.5, scale=0.75, size=N) +# t5 = norm.rvs(loc=3.25, scale=0.4, size=N) +# t6 = norm.rvs(loc=3.25, scale=0.4, size=N) + +# # Add a `gender` column for coloring the data. +# females = np.repeat("Female", N / 2).tolist() +# males = np.repeat("Male", N / 2).tolist() +# gender = females + males + +# # Add an `id` column for paired data plotting. +# id_col = pd.Series(range(1, N + 1)) + +# # Combine samples and gender into a DataFrame. +# df = pd.DataFrame( +# { +# "Control 1": c1, +# "Test 1": t1, +# "Control 2": c2, +# "Test 2": t2, +# "Control 3": c3, +# "Test 3": t3, +# "Test 4": t4, +# "Test 5": t5, +# "Test 6": t6, +# "Gender": gender, +# "ID": id_col, +# } +# ) + +# return df + + +# df = create_demo_dataset() + +# two_groups_unpaired = load(df, idx=("Control 1", "Test 1")) + +# two_groups_paired = load( +# df, idx=("Control 1", "Test 1"), paired="baseline", id_col="ID" +# ) + +# two_groups_unpaired.mean_diff.plot()