Skip to content

Commit

Permalink
Merge pull request #175 from ACCLAB/JAnns98-patch-1
Browse files Browse the repository at this point in the history
Update 07-forest_plot.ipynb, solve some color palette issue
  • Loading branch information
Jacobluke- authored Sep 24, 2024
2 parents d1c123f + b8a358f commit f8fa263
Show file tree
Hide file tree
Showing 16 changed files with 330 additions and 1,212 deletions.
4 changes: 2 additions & 2 deletions dabest/misc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def get_kwargs(plot_kwargs, ytick_color):
delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs)


def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):
def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_groups):

# Create color palette that will be shared across subplots.
color_col = plot_kwargs["color_col"]
Expand Down Expand Up @@ -350,7 +350,7 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):
else:
if isinstance(custom_pal, dict):
groups_in_palette = {
k: v for k, v in custom_pal.items() if k in color_groups
k: custom_pal[k] for k in all_plot_groups if k in color_groups
}

names = groups_in_palette.keys()
Expand Down
29 changes: 20 additions & 9 deletions dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def sankeydiag(

def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object,
float_contrast: bool,summary_bars_kwargs: dict, ci_type: str,
ticks_to_plot: list, color_col: str, swarm_colors: list,
ticks_to_plot: list, color_col: str, plot_palette_raw: dict,
proportional: bool, is_paired: bool):
"""
Add summary bars to the contrast plot.
Expand All @@ -843,8 +843,8 @@ def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object
List of indices of the contrast objects.
color_col : str
Column name of the color column.
swarm_colors : list
List of colors used in the plot.
plot_palette_raw : dict
Dictionary of colors used in the plot.
proportional : bool
Whether the data is proportional.
is_paired : bool
Expand All @@ -862,7 +862,13 @@ def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object
# End checks
else:
summary_xmin, summary_xmax = ax_to_plot.get_xlim()
summary_bars_colors = [summary_bars_kwargs.get('color')]*(max(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
summary_bars_colors = (
[summary_bars_kwargs.get('color')]*(max(summary_bars)+1)
if summary_bars_kwargs.get('color') is not None
else ['black']*(max(summary_bars)+1)
if color_col is not None or (proportional and is_paired) or is_paired
else list(plot_palette_raw.values())
)
summary_bars_kwargs.pop('color')
for summary_index in summary_bars:
if ci_type == "bca":
Expand Down Expand Up @@ -973,7 +979,6 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
swarm_bars_order = pd.unique(plot_data[xvar])

swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)
# swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(max(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors
swarm_bars_colors = (
[swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1)
if swarm_bars_kwargs.get('color') is not None
Expand All @@ -987,7 +992,7 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs))

def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, ticks_to_plot: list, delta_text_kwargs: dict, color_col: str,
swarm_colors: list, is_paired: bool, proportional: bool, float_contrast: bool,
plot_palette_raw: dict, is_paired: bool, proportional: bool, float_contrast: bool,
show_mini_meta: bool, mini_meta_delta: object, show_delta2: bool, delta_delta: object):
"""
Add text to the contrast plot.
Expand All @@ -1006,8 +1011,8 @@ def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: objec
Keyword arguments for the delta text.
color_col : str
Column name of the color column.
swarm_colors : list
List of colors used in the plot.
plot_palette_raw : dict
Dictionary of colors used in the plot.
is_paired : bool
Whether the data is paired.
proportional : bool
Expand All @@ -1032,7 +1037,13 @@ def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: objec
delta_text_kwargs["va"] = 'bottom' if results.difference[0] >= 0 else 'top'
delta_text_kwargs.pop('x_location')

delta_text_colors = [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1) if delta_text_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
delta_text_colors = (
[delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1)
if delta_text_kwargs.get('color') is not None
else ['black']*(max(ticks_to_plot)+1)
if color_col is not None or (proportional and is_paired) or is_paired
else list(plot_palette_raw.values())
)
if show_mini_meta or show_delta2: delta_text_colors.append('black')
delta_text_kwargs.pop('color')

Expand Down
7 changes: 4 additions & 3 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
plot_data=plot_data,
xvar=xvar,
show_pairs=show_pairs,
idx=idx
idx=idx,
all_plot_groups=all_plot_groups
)

# Initialise the figure.
Expand Down Expand Up @@ -551,7 +552,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
ci_type=ci_type,
ticks_to_plot=ticks_to_plot,
color_col=color_col,
swarm_colors=swarm_colors,
plot_palette_raw=plot_palette_raw,
proportional=proportional,
is_paired=is_paired
)
Expand All @@ -565,7 +566,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
ticks_to_plot=ticks_to_plot,
delta_text_kwargs=delta_text_kwargs,
color_col=color_col,
swarm_colors=swarm_colors,
plot_palette_raw=plot_palette_raw,
is_paired=is_paired,
proportional=proportional,
float_contrast=float_contrast,
Expand Down
4 changes: 2 additions & 2 deletions nbs/API/misc_tools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@
" delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs)\n",
"\n",
"\n",
"def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):\n",
"def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_groups):\n",
"\n",
" # Create color palette that will be shared across subplots.\n",
" color_col = plot_kwargs[\"color_col\"]\n",
Expand Down Expand Up @@ -403,7 +403,7 @@
" else:\n",
" if isinstance(custom_pal, dict):\n",
" groups_in_palette = {\n",
" k: v for k, v in custom_pal.items() if k in color_groups\n",
" k: custom_pal[k] for k in all_plot_groups if k in color_groups\n",
" }\n",
"\n",
" names = groups_in_palette.keys()\n",
Expand Down
29 changes: 20 additions & 9 deletions nbs/API/plot_tools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@
"\n",
"def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object,\n",
" float_contrast: bool,summary_bars_kwargs: dict, ci_type: str,\n",
" ticks_to_plot: list, color_col: str, swarm_colors: list, \n",
" ticks_to_plot: list, color_col: str, plot_palette_raw: dict, \n",
" proportional: bool, is_paired: bool):\n",
" \"\"\"\n",
" Add summary bars to the contrast plot.\n",
Expand All @@ -894,8 +894,8 @@
" List of indices of the contrast objects.\n",
" color_col : str\n",
" Column name of the color column.\n",
" swarm_colors : list\n",
" List of colors used in the plot.\n",
" plot_palette_raw : dict\n",
" Dictionary of colors used in the plot.\n",
" proportional : bool\n",
" Whether the data is proportional.\n",
" is_paired : bool\n",
Expand All @@ -913,7 +913,13 @@
"# End checks\n",
" else:\n",
" summary_xmin, summary_xmax = ax_to_plot.get_xlim()\n",
" summary_bars_colors = [summary_bars_kwargs.get('color')]*(max(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors\n",
" summary_bars_colors = (\n",
" [summary_bars_kwargs.get('color')]*(max(summary_bars)+1)\n",
" if summary_bars_kwargs.get('color') is not None\n",
" else ['black']*(max(summary_bars)+1)\n",
" if color_col is not None or (proportional and is_paired) or is_paired \n",
" else list(plot_palette_raw.values())\n",
" )\n",
" summary_bars_kwargs.pop('color')\n",
" for summary_index in summary_bars:\n",
" if ci_type == \"bca\":\n",
Expand Down Expand Up @@ -1024,7 +1030,6 @@
" swarm_bars_order = pd.unique(plot_data[xvar])\n",
"\n",
" swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)\n",
" # swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(max(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors\n",
" swarm_bars_colors = (\n",
" [swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1) \n",
" if swarm_bars_kwargs.get('color') is not None \n",
Expand All @@ -1038,7 +1043,7 @@
" 0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs))\n",
"\n",
"def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, ticks_to_plot: list, delta_text_kwargs: dict, color_col: str, \n",
" swarm_colors: list, is_paired: bool, proportional: bool, float_contrast: bool,\n",
" plot_palette_raw: dict, is_paired: bool, proportional: bool, float_contrast: bool,\n",
" show_mini_meta: bool, mini_meta_delta: object, show_delta2: bool, delta_delta: object):\n",
" \"\"\"\n",
" Add text to the contrast plot.\n",
Expand All @@ -1057,8 +1062,8 @@
" Keyword arguments for the delta text.\n",
" color_col : str\n",
" Column name of the color column.\n",
" swarm_colors : list\n",
" List of colors used in the plot.\n",
" plot_palette_raw : dict\n",
" Dictionary of colors used in the plot.\n",
" is_paired : bool\n",
" Whether the data is paired.\n",
" proportional : bool\n",
Expand All @@ -1083,7 +1088,13 @@
" delta_text_kwargs[\"va\"] = 'bottom' if results.difference[0] >= 0 else 'top'\n",
" delta_text_kwargs.pop('x_location')\n",
"\n",
" delta_text_colors = [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1) if delta_text_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors\n",
" delta_text_colors = (\n",
" [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1)\n",
" if delta_text_kwargs.get('color') is not None\n",
" else ['black']*(max(ticks_to_plot)+1)\n",
" if color_col is not None or (proportional and is_paired) or is_paired\n",
" else list(plot_palette_raw.values())\n",
" )\n",
" if show_mini_meta or show_delta2: delta_text_colors.append('black')\n",
" delta_text_kwargs.pop('color')\n",
"\n",
Expand Down
7 changes: 4 additions & 3 deletions nbs/API/plotter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@
" plot_data=plot_data, \n",
" xvar=xvar, \n",
" show_pairs=show_pairs,\n",
" idx=idx\n",
" idx=idx,\n",
" all_plot_groups=all_plot_groups\n",
" )\n",
"\n",
" # Initialise the figure.\n",
Expand Down Expand Up @@ -608,7 +609,7 @@
" ci_type=ci_type, \n",
" ticks_to_plot=ticks_to_plot, \n",
" color_col=color_col,\n",
" swarm_colors=swarm_colors, \n",
" plot_palette_raw=plot_palette_raw, \n",
" proportional=proportional, \n",
" is_paired=is_paired\n",
" )\n",
Expand All @@ -622,7 +623,7 @@
" ticks_to_plot=ticks_to_plot, \n",
" delta_text_kwargs=delta_text_kwargs, \n",
" color_col=color_col, \n",
" swarm_colors=swarm_colors, \n",
" plot_palette_raw=plot_palette_raw, \n",
" is_paired=is_paired,\n",
" proportional=proportional, \n",
" float_contrast=float_contrast, \n",
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 35 additions & 0 deletions nbs/tests/mpl_image_tests/test_plot_aesthetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
mpl.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as Ticker
import seaborn as sns

from dabest._api import load

Expand Down Expand Up @@ -146,6 +147,32 @@ def create_demo_dataset_delta(seed=9999, N=20):
unpaired_mini_meta = load(df, idx=(("Control 1", "Test 1"), ("Control 2", "Test 2"), ("Control 3", "Test 3")),
mini_meta=True)

multi_groups_change_idx_original = load(
df,
idx=(
("Control 1", "Test 1", "Test 2"),
("Control 2", "Test 3", "Test 4"),
("Control 3", "Test 5", "Test 6"),
),
)
multi_groups_change_idx_new = load(
df,
idx=(
("Control 1", "Control 2", "Control 3"),
("Test 1", "Test 3", "Test 5"),
("Test 2", "Test 4", "Test 6"),
),
)
palette = {"Control 1": sns.color_palette("magma")[5],
"Test 1": sns.color_palette("magma")[3],
"Test 2": sns.color_palette("magma")[1],
"Control 2": sns.color_palette("magma")[5],
"Test 3": sns.color_palette("magma")[3],
"Test 4": sns.color_palette("magma")[1],
"Control 3": sns.color_palette("magma")[5],
"Test 5": sns.color_palette("magma")[3],
"Test 6": sns.color_palette("magma")[1]}

@pytest.mark.mpl_image_compare(tolerance=8)
def test_207_gardner_altman_meandiff_empty_circle():
return two_groups_unpaired.mean_diff.plot(empty_circle=True);
Expand Down Expand Up @@ -173,3 +200,11 @@ def test_212_cummings_unpaired_delta_delta_meandiff_empty_circle():
@pytest.mark.mpl_image_compare(tolerance=8)
def test_213_cummings_unpaired_mini_meta_meandiff_empty_circle():
return unpaired_mini_meta.mean_diff.plot(empty_circle=True);

@pytest.mark.mpl_image_compare(tolerance=8)
def test_214_change_idx_order_custom_palette_original():
return multi_groups_change_idx_original.mean_diff.plot(custom_palette=palette);

@pytest.mark.mpl_image_compare(tolerance=8)
def test_215_change_idx_order_custom_palette_new():
return multi_groups_change_idx_new.mean_diff.plot(custom_palette=palette);
Loading

0 comments on commit f8fa263

Please sign in to comment.