Skip to content

Commit

Permalink
Sankey functionality added to horizontal plots
Browse files Browse the repository at this point in the history
  • Loading branch information
JAnns98 committed Sep 28, 2024
1 parent f893736 commit 7ab566e
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 48 deletions.
84 changes: 60 additions & 24 deletions dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def single_sankey(
one_sankey: bool = False, # if True, only draw one sankey diagram
right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels
align: str = "center", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick
horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used
):
"""
Make a single Sankey diagram showing proportion flow from left to right
Expand Down Expand Up @@ -542,16 +543,24 @@ def single_sankey(

# Plot a bar for each label (vertical by default; drawn with fill_betweenx when horizontal is True)
for left_label in left_labels:
ax.fill_between(
[leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)],
2 * [leftWidths_norm[left_label]["bottom"]],
2 * [leftWidths_norm[left_label]["top"]],
color=colorDict[left_label],
alpha=0.99,
)
if horizontal:
fill_method = ax.fill_betweenx
else:
fill_method = ax.fill_between
fill_method(
[leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)],
2 * [leftWidths_norm[left_label]["bottom"]],
2 * [leftWidths_norm[left_label]["top"]],
color=colorDict[left_label],
alpha=0.99,
)
if (not flow and sankey) or one_sankey:
for right_label in right_labels:
ax.fill_between(
if horizontal:
fill_method = ax.fill_betweenx
else:
fill_method = ax.fill_between
fill_method(
[
xMax + leftpos + (-bar_width * xMax * 0.5),
leftpos + xMax + (bar_width * xMax * 0.5),
Expand All @@ -564,16 +573,29 @@ def single_sankey(

# Plot error bars
if error_bar_on and strip_on:
error_bar(
concatenated_df,
x="groups",
y="values",
ax=ax,
offset=0,
gap_width_percent=2,
method="sankey_error_bar",
pos=[leftpos, leftpos + xMax],
)
if horizontal:
error_bar(
concatenated_df,
x="groups",
y="values",
ax=ax,
offset=0,
gap_width_percent=2,
method="sankey_error_bar",
pos=[leftpos, leftpos + xMax],
horizontal=True,
)
else:
error_bar(
concatenated_df,
x="groups",
y="values",
ax=ax,
offset=0,
gap_width_percent=2,
method="sankey_error_bar",
pos=[leftpos, leftpos + xMax],
)

# Determine widths of individual strips, all widths are normalized to 1
ns_l = defaultdict()
Expand Down Expand Up @@ -631,7 +653,11 @@ def single_sankey(
rightWidths_norm[right_label]["bottom"] += ns_r_norm[left_label][
right_label
]
ax.fill_between(
if horizontal:
fill_method = ax.fill_betweenx
else:
fill_method = ax.fill_between
fill_method(
np.linspace(
leftpos + (bar_width * xMax * 0.5),
leftpos + xMax - (bar_width * xMax * 0.5),
Expand All @@ -644,7 +670,6 @@ def single_sankey(
edgecolor="none",
)


def sankeydiag(
data: pd.DataFrame,
xvar: str, # x column to be plotted.
Expand All @@ -663,6 +688,7 @@ def sankeydiag(
right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels
align: str = "center", # the alignment of each sankey diagram, can be 'center' or 'edge'
alpha: float = 0.65, # the transparency of each strip
horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used
**kwargs,
):
"""
Expand Down Expand Up @@ -790,6 +816,7 @@ def sankeydiag(
flow=flow,
align=align,
alpha=alpha,
horizontal=horizontal,
)
xpos += 1
else:
Expand All @@ -813,6 +840,7 @@ def sankeydiag(
flow=False,
align="edge",
alpha=alpha,
horizontal=horizontal,
)

# Now only draw vs xticks for two-column sankey diagram
Expand All @@ -825,12 +853,20 @@ def sankeydiag(
for left, right in zip(broadcasted_left, right_idx)
]
)
ax.get_xaxis().set_ticks(np.arange(len(right_idx)))
ax.get_xaxis().set_ticklabels(sankey_ticks)
if horizontal:
ax.get_yaxis().set_ticks(np.arange(len(right_idx)))
ax.get_yaxis().set_ticklabels(sankey_ticks)
else:
ax.get_xaxis().set_ticks(np.arange(len(right_idx)))
ax.get_xaxis().set_ticklabels(sankey_ticks)
else:
sankey_ticks = [broadcasted_left[0], right_idx[0]]
ax.set_xticks([0, 1])
ax.set_xticklabels(sankey_ticks)
if horizontal:
ax.set_yticks([0, 1])
ax.set_yticklabels(sankey_ticks)
else:
ax.set_xticks([0, 1])
ax.set_xticklabels(sankey_ticks)

return left_idx, right_idx

Expand Down
1 change: 1 addition & 0 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ def effectsize_df_plotter_horizontal(effectsize_df, **plot_kwargs):
temp_idx=temp_idx,
palette=plot_palette_sankey,
ax=rawdata_axes,
horizontal=True,
**sankey_kwargs
)

Expand Down
84 changes: 60 additions & 24 deletions nbs/API/plot_tools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@
" one_sankey: bool = False, # if True, only draw one sankey diagram\n",
" right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels\n",
" align: str = \"center\", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick\n",
" horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used\n",
"):\n",
" \"\"\"\n",
" Make a single Sankey diagram showing proportion flow from left to right\n",
Expand Down Expand Up @@ -594,16 +595,24 @@
"\n",
" # Plot vertical bars for each label\n",
" for left_label in left_labels:\n",
" ax.fill_between(\n",
" [leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)],\n",
" 2 * [leftWidths_norm[left_label][\"bottom\"]],\n",
" 2 * [leftWidths_norm[left_label][\"top\"]],\n",
" color=colorDict[left_label],\n",
" alpha=0.99,\n",
" )\n",
" if horizontal:\n",
" fill_method = ax.fill_betweenx\n",
" else:\n",
" fill_method = ax.fill_between\n",
" fill_method(\n",
" [leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)],\n",
" 2 * [leftWidths_norm[left_label][\"bottom\"]],\n",
" 2 * [leftWidths_norm[left_label][\"top\"]],\n",
" color=colorDict[left_label],\n",
" alpha=0.99,\n",
" )\n",
" if (not flow and sankey) or one_sankey:\n",
" for right_label in right_labels:\n",
" ax.fill_between(\n",
" if horizontal:\n",
" fill_method = ax.fill_betweenx\n",
" else:\n",
" fill_method = ax.fill_between\n",
" fill_method(\n",
" [\n",
" xMax + leftpos + (-bar_width * xMax * 0.5),\n",
" leftpos + xMax + (bar_width * xMax * 0.5),\n",
Expand All @@ -616,16 +625,29 @@
"\n",
" # Plot error bars\n",
" if error_bar_on and strip_on:\n",
" error_bar(\n",
" concatenated_df,\n",
" x=\"groups\",\n",
" y=\"values\",\n",
" ax=ax,\n",
" offset=0,\n",
" gap_width_percent=2,\n",
" method=\"sankey_error_bar\",\n",
" pos=[leftpos, leftpos + xMax],\n",
" )\n",
" if horizontal:\n",
" error_bar(\n",
" concatenated_df,\n",
" x=\"groups\",\n",
" y=\"values\",\n",
" ax=ax,\n",
" offset=0,\n",
" gap_width_percent=2,\n",
" method=\"sankey_error_bar\",\n",
" pos=[leftpos, leftpos + xMax],\n",
" horizontal=True,\n",
" )\n",
" else:\n",
" error_bar(\n",
" concatenated_df,\n",
" x=\"groups\",\n",
" y=\"values\",\n",
" ax=ax,\n",
" offset=0,\n",
" gap_width_percent=2,\n",
" method=\"sankey_error_bar\",\n",
" pos=[leftpos, leftpos + xMax],\n",
" )\n",
"\n",
" # Determine widths of individual strips, all widths are normalized to 1\n",
" ns_l = defaultdict()\n",
Expand Down Expand Up @@ -683,7 +705,11 @@
" rightWidths_norm[right_label][\"bottom\"] += ns_r_norm[left_label][\n",
" right_label\n",
" ]\n",
" ax.fill_between(\n",
" if horizontal:\n",
" fill_method = ax.fill_betweenx\n",
" else:\n",
" fill_method = ax.fill_between\n",
" fill_method(\n",
" np.linspace(\n",
" leftpos + (bar_width * xMax * 0.5),\n",
" leftpos + xMax - (bar_width * xMax * 0.5),\n",
Expand All @@ -696,7 +722,6 @@
" edgecolor=\"none\",\n",
" )\n",
"\n",
"\n",
"def sankeydiag(\n",
" data: pd.DataFrame,\n",
" xvar: str, # x column to be plotted.\n",
Expand All @@ -715,6 +740,7 @@
" right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels\n",
" align: str = \"center\", # the alignment of each sankey diagram, can be 'center' or 'left'\n",
" alpha: float = 0.65, # the transparency of each strip\n",
" horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used\n",
" **kwargs,\n",
"):\n",
" \"\"\"\n",
Expand Down Expand Up @@ -842,6 +868,7 @@
" flow=flow,\n",
" align=align,\n",
" alpha=alpha,\n",
" horizontal=horizontal,\n",
" )\n",
" xpos += 1\n",
" else:\n",
Expand All @@ -865,6 +892,7 @@
" flow=False,\n",
" align=\"edge\",\n",
" alpha=alpha,\n",
" horizontal=horizontal,\n",
" )\n",
"\n",
" # Now only draw vs xticks for two-column sankey diagram\n",
Expand All @@ -877,12 +905,20 @@
" for left, right in zip(broadcasted_left, right_idx)\n",
" ]\n",
" )\n",
" ax.get_xaxis().set_ticks(np.arange(len(right_idx)))\n",
" ax.get_xaxis().set_ticklabels(sankey_ticks)\n",
" if horizontal:\n",
" ax.get_yaxis().set_ticks(np.arange(len(right_idx)))\n",
" ax.get_yaxis().set_ticklabels(sankey_ticks)\n",
" else:\n",
" ax.get_xaxis().set_ticks(np.arange(len(right_idx)))\n",
" ax.get_xaxis().set_ticklabels(sankey_ticks)\n",
" else:\n",
" sankey_ticks = [broadcasted_left[0], right_idx[0]]\n",
" ax.set_xticks([0, 1])\n",
" ax.set_xticklabels(sankey_ticks)\n",
" if horizontal:\n",
" ax.set_yticks([0, 1])\n",
" ax.set_yticklabels(sankey_ticks) \n",
" else:\n",
" ax.set_xticks([0, 1])\n",
" ax.set_xticklabels(sankey_ticks)\n",
"\n",
" return left_idx, right_idx\n",
"\n",
Expand Down
1 change: 1 addition & 0 deletions nbs/API/plotter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@
" temp_idx=temp_idx,\n",
" palette=plot_palette_sankey,\n",
" ax=rawdata_axes,\n",
" horizontal=True,\n",
" **sankey_kwargs\n",
" )\n",
"\n",
Expand Down

0 comments on commit 7ab566e

Please sign in to comment.