Skip to content

Commit

Permalink
Preliminary changes to plotter and scatterplot functions to allow leg…
Browse files Browse the repository at this point in the history
…end plotting for swarmplots with color_col=True (not delta2/minimeta)
  • Loading branch information
JAnns98 committed Sep 13, 2024
1 parent ddeefa7 commit 2df0aff
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 36 deletions.
7 changes: 6 additions & 1 deletion dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,10 @@ def plot(
cmap = []
for cmap_group_i in cmap_values:
cmap.append(self.__palette[cmap_group_i])

# WIP: legend for swarm plot
swarm_legend_kwargs = {'colors':cmap, 'labels':cmap_values, 'index':index}

cmap = ListedColormap(cmap)
ax.scatter(
values_i["x_new"],
Expand All @@ -2000,6 +2004,7 @@ def plot(
edgecolor="face",
**kwargs,
)

else:
# color swarms based on `x` column
ax.scatter(
Expand All @@ -2015,4 +2020,4 @@ def plot(
ax.get_xaxis().set_ticks(np.arange(x_position))
ax.get_xaxis().set_ticklabels(x_tick_tabels)

return ax
return ax, swarm_legend_kwargs if self.__hue is not None else None
45 changes: 28 additions & 17 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import pandas as pd
import warnings
import logging
Expand Down Expand Up @@ -219,21 +220,21 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):

# swarmplot() plots swarms based on current size of ax
# Therefore, since the ax size for mini_meta and show_delta changes later on, there has to be increased jitter
rawdata_plot = swarmplot(
data=plot_data,
x=xvar,
y=yvar,
ax=rawdata_axes,
order=all_plot_groups,
hue=xvar if color_col is None else color_col,
palette=plot_palette_raw,
zorder=1,
side=asymmetric_side,
jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value
is_drop_gutter=True,
gutter_limit=0.45,
**swarmplot_kwargs
)
rawdata_plot, swarm_legend_kwargs = swarmplot(
data=plot_data,
x=xvar,
y=yvar,
ax=rawdata_axes,
order=all_plot_groups,
hue=xvar if color_col is None else color_col,
palette=plot_palette_raw,
zorder=1,
side=asymmetric_side,
jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value
is_drop_gutter=True,
gutter_limit=0.45,
**swarmplot_kwargs
)
if color_col is None:
rawdata_plot.legend().set_visible(False)

Expand Down Expand Up @@ -384,10 +385,9 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
handles, labels = rawdata_axes.get_legend_handles_labels()
legend_labels = [l for l in labels]
legend_handles = [h for h in handles]
if bootstraps_color_by_group is False:
rawdata_axes.legend().set_visible(False)

if bootstraps_color_by_group is False:
rawdata_axes.legend().set_visible(False)
show_legend(
legend_labels=legend_labels,
legend_handles=legend_handles,
Expand All @@ -398,6 +398,17 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
legend_kwargs=legend_kwargs
)

########## WIP LEGENDS
if not show_pairs and not proportional and color_col is not None and not show_delta2:
if len(np.unique(swarm_legend_kwargs['index'])) > 1:
legend_elements = []
for color, label in zip(swarm_legend_kwargs['colors'], swarm_legend_kwargs['labels']):
legend_elements.append(Line2D([0], [0], marker='o', color='w', label=label,
markerfacecolor=color, markersize=10))
rawdata_axes.legend(handles=legend_elements, frameon=False)

########## WIP LEGENDS

# Plot aesthetic adjustments.
og_ylim_raw = rawdata_axes.get_ylim()
og_xlim_raw = rawdata_axes.get_xlim()
Expand Down
15 changes: 14 additions & 1 deletion nbs/API/plot_tools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2050,6 +2050,10 @@
" cmap = []\n",
" for cmap_group_i in cmap_values:\n",
" cmap.append(self.__palette[cmap_group_i])\n",
"\n",
" # WIP: legend for swarm plot\n",
" swarm_legend_kwargs = {'colors':cmap, 'labels':cmap_values, 'index':index}\n",
"\n",
" cmap = ListedColormap(cmap)\n",
" ax.scatter(\n",
" values_i[\"x_new\"],\n",
Expand All @@ -2061,6 +2065,7 @@
" edgecolor=\"face\",\n",
" **kwargs,\n",
" )\n",
"\n",
" else:\n",
" # color swarms based on `x` column\n",
" ax.scatter(\n",
Expand All @@ -2076,8 +2081,16 @@
" ax.get_xaxis().set_ticks(np.arange(x_position))\n",
" ax.get_xaxis().set_ticklabels(x_tick_tabels)\n",
"\n",
" return ax"
" return ax, swarm_legend_kwargs if self.__hue is not None else None"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "022ea903",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
45 changes: 28 additions & 17 deletions nbs/API/plotter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as mpatches\n",
"from matplotlib.lines import Line2D\n",
"import pandas as pd\n",
"import warnings\n",
"import logging"
Expand Down Expand Up @@ -278,21 +279,21 @@
"\n",
" # swarmplot() plots swarms based on current size of ax\n",
" # Therefore, since the ax size for mini_meta and show_delta changes later on, there has to be increased jitter\n",
" rawdata_plot = swarmplot(\n",
" data=plot_data,\n",
" x=xvar,\n",
" y=yvar,\n",
" ax=rawdata_axes,\n",
" order=all_plot_groups,\n",
" hue=xvar if color_col is None else color_col,\n",
" palette=plot_palette_raw,\n",
" zorder=1,\n",
" side=asymmetric_side,\n",
" jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n",
" is_drop_gutter=True,\n",
" gutter_limit=0.45,\n",
" **swarmplot_kwargs\n",
" )\n",
" rawdata_plot, swarm_legend_kwargs = swarmplot(\n",
" data=plot_data,\n",
" x=xvar,\n",
" y=yvar,\n",
" ax=rawdata_axes,\n",
" order=all_plot_groups,\n",
" hue=xvar if color_col is None else color_col,\n",
" palette=plot_palette_raw,\n",
" zorder=1,\n",
" side=asymmetric_side,\n",
" jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n",
" is_drop_gutter=True,\n",
" gutter_limit=0.45,\n",
" **swarmplot_kwargs\n",
" )\n",
" if color_col is None:\n",
" rawdata_plot.legend().set_visible(False)\n",
"\n",
Expand Down Expand Up @@ -443,10 +444,9 @@
" handles, labels = rawdata_axes.get_legend_handles_labels()\n",
" legend_labels = [l for l in labels]\n",
" legend_handles = [h for h in handles]\n",
" if bootstraps_color_by_group is False:\n",
" rawdata_axes.legend().set_visible(False)\n",
"\n",
" if bootstraps_color_by_group is False:\n",
" rawdata_axes.legend().set_visible(False)\n",
" show_legend(\n",
" legend_labels=legend_labels, \n",
" legend_handles=legend_handles, \n",
Expand All @@ -457,6 +457,17 @@
" legend_kwargs=legend_kwargs\n",
" )\n",
"\n",
" ########## WIP LEGENDS\n",
" if not show_pairs and not proportional and color_col is not None and not show_delta2:\n",
" if len(np.unique(swarm_legend_kwargs['index'])) > 1:\n",
" legend_elements = []\n",
" for color, label in zip(swarm_legend_kwargs['colors'], swarm_legend_kwargs['labels']):\n",
" legend_elements.append(Line2D([0], [0], marker='o', color='w', label=label,\n",
" markerfacecolor=color, markersize=10))\n",
" rawdata_axes.legend(handles=legend_elements, frameon=False)\n",
"\n",
" ########## WIP LEGENDS\n",
"\n",
" # Plot aesthetic adjustments.\n",
" og_ylim_raw = rawdata_axes.get_ylim()\n",
" og_xlim_raw = rawdata_axes.get_xlim()\n",
Expand Down

0 comments on commit 2df0aff

Please sign in to comment.