From 7ab566eeee7e1f914f4639506eb1bac6d5c89a75 Mon Sep 17 00:00:00 2001 From: JAnns98 Date: Sat, 28 Sep 2024 23:08:44 +0800 Subject: [PATCH] Sankey functionality added to horizontal plots --- dabest/plot_tools.py | 84 ++++++++++++++++++++++++++++------------ dabest/plotter.py | 1 + nbs/API/plot_tools.ipynb | 84 ++++++++++++++++++++++++++++------------ nbs/API/plotter.ipynb | 1 + 4 files changed, 122 insertions(+), 48 deletions(-) diff --git a/dabest/plot_tools.py b/dabest/plot_tools.py index 86843107..98ce3bee 100644 --- a/dabest/plot_tools.py +++ b/dabest/plot_tools.py @@ -382,6 +382,7 @@ def single_sankey( one_sankey: bool = False, # if True, only draw one sankey diagram right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels align: str = "center", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick + horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used ): """ Make a single Sankey diagram showing proportion flow from left to right @@ -542,16 +543,24 @@ def single_sankey( # Plot vertical bars for each label for left_label in left_labels: - ax.fill_between( - [leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)], - 2 * [leftWidths_norm[left_label]["bottom"]], - 2 * [leftWidths_norm[left_label]["top"]], - color=colorDict[left_label], - alpha=0.99, - ) + if horizontal: + fill_method = ax.fill_betweenx + else: + fill_method = ax.fill_between + fill_method( + [leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)], + 2 * [leftWidths_norm[left_label]["bottom"]], + 2 * [leftWidths_norm[left_label]["top"]], + color=colorDict[left_label], + alpha=0.99, + ) if (not flow and sankey) or one_sankey: for right_label in right_labels: - ax.fill_between( + if horizontal: + fill_method = ax.fill_betweenx + else: + fill_method = ax.fill_between + fill_method( [ xMax + leftpos + (-bar_width * xMax * 0.5), leftpos + xMax + (bar_width * xMax * 0.5), @@ -564,16 +573,29 @@ def single_sankey( # Plot error bars if error_bar_on and strip_on: - error_bar( - concatenated_df, - x="groups", - y="values", - ax=ax, - offset=0, - gap_width_percent=2, - method="sankey_error_bar", - pos=[leftpos, leftpos + xMax], - ) + if horizontal: + error_bar( + concatenated_df, + x="groups", + y="values", + ax=ax, + offset=0, + gap_width_percent=2, + method="sankey_error_bar", + pos=[leftpos, leftpos + xMax], + horizontal=True, + ) + else: + error_bar( + concatenated_df, + x="groups", + y="values", + ax=ax, + offset=0, + gap_width_percent=2, + method="sankey_error_bar", + pos=[leftpos, leftpos + xMax], + ) # Determine widths of individual strips, all widths are normalized to 1 ns_l = defaultdict() @@ -631,7 +653,11 @@ def single_sankey( rightWidths_norm[right_label]["bottom"] += ns_r_norm[left_label][ right_label ] - ax.fill_between( + if horizontal: + fill_method = ax.fill_betweenx + else: + fill_method = ax.fill_between + fill_method( np.linspace( leftpos + (bar_width * xMax * 0.5), leftpos + xMax - (bar_width * xMax * 0.5), @@ -644,7 +670,6 @@ def single_sankey( edgecolor="none", ) - def sankeydiag( data: pd.DataFrame, xvar: str, # x column to be plotted. @@ -663,6 +688,7 @@ def sankeydiag( right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels align: str = "center", # the alignment of each sankey diagram, can be 'center' or 'left' alpha: float = 0.65, # the transparency of each strip + horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used **kwargs, ): """ @@ -790,6 +816,7 @@ def sankeydiag( flow=flow, align=align, alpha=alpha, + horizontal=horizontal, ) xpos += 1 else: @@ -813,6 +840,7 @@ def sankeydiag( flow=False, align="edge", alpha=alpha, + horizontal=horizontal, ) # Now only draw vs xticks for two-column sankey diagram @@ -825,12 +853,20 @@ def sankeydiag( for left, right in zip(broadcasted_left, right_idx) ] ) - ax.get_xaxis().set_ticks(np.arange(len(right_idx))) - ax.get_xaxis().set_ticklabels(sankey_ticks) + if horizontal: + ax.get_yaxis().set_ticks(np.arange(len(right_idx))) + ax.get_yaxis().set_ticklabels(sankey_ticks) + else: + ax.get_xaxis().set_ticks(np.arange(len(right_idx))) + ax.get_xaxis().set_ticklabels(sankey_ticks) else: sankey_ticks = [broadcasted_left[0], right_idx[0]] - ax.set_xticks([0, 1]) - ax.set_xticklabels(sankey_ticks) + if horizontal: + ax.set_yticks([0, 1]) + ax.set_yticklabels(sankey_ticks) + else: + ax.set_xticks([0, 1]) + ax.set_xticklabels(sankey_ticks) return left_idx, right_idx diff --git a/dabest/plotter.py b/dabest/plotter.py index 5970f9f4..19acbd8d 100644 --- a/dabest/plotter.py +++ b/dabest/plotter.py @@ -771,6 +771,7 @@ def effectsize_df_plotter_horizontal(effectsize_df, **plot_kwargs): temp_idx=temp_idx, palette=plot_palette_sankey, ax=rawdata_axes, + horizontal=True, **sankey_kwargs ) diff --git a/nbs/API/plot_tools.ipynb b/nbs/API/plot_tools.ipynb index 6125277f..2cfbe96b 100644 --- a/nbs/API/plot_tools.ipynb +++ b/nbs/API/plot_tools.ipynb @@ -434,6 +434,7 @@ " one_sankey: bool = False, # if True, only draw one sankey diagram\n", " right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels\n", " align: str = \"center\", # if 'center', the diagram will be centered on each xtick, if 'edge', the diagram will be aligned with the left edge of each xtick\n", + " horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used\n", "):\n", " \"\"\"\n", " Make a single Sankey diagram showing proportion flow from left to right\n", @@ -594,16 +595,24 @@ "\n", " # Plot vertical bars for each label\n", " for left_label in left_labels:\n", - " ax.fill_between(\n", - " [leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)],\n", - " 2 * [leftWidths_norm[left_label][\"bottom\"]],\n", - " 2 * [leftWidths_norm[left_label][\"top\"]],\n", - " color=colorDict[left_label],\n", - " alpha=0.99,\n", - " )\n", + " if horizontal:\n", + " fill_method = ax.fill_betweenx\n", + " else:\n", + " fill_method = ax.fill_between\n", + " fill_method(\n", + " [leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)],\n", + " 2 * [leftWidths_norm[left_label][\"bottom\"]],\n", + " 2 * [leftWidths_norm[left_label][\"top\"]],\n", + " color=colorDict[left_label],\n", + " alpha=0.99,\n", + " )\n", " if (not flow and sankey) or one_sankey:\n", " for right_label in right_labels:\n", - " ax.fill_between(\n", + " if horizontal:\n", + " fill_method = ax.fill_betweenx\n", + " else:\n", + " fill_method = ax.fill_between\n", + " fill_method(\n", " [\n", " xMax + leftpos + (-bar_width * xMax * 0.5),\n", " leftpos + xMax + (bar_width * xMax * 0.5),\n", @@ -616,16 +625,29 @@ "\n", " # Plot error bars\n", " if error_bar_on and strip_on:\n", - " error_bar(\n", - " concatenated_df,\n", - " x=\"groups\",\n", - " y=\"values\",\n", - " ax=ax,\n", - " offset=0,\n", - " gap_width_percent=2,\n", - " method=\"sankey_error_bar\",\n", - " pos=[leftpos, leftpos + xMax],\n", - " )\n", + " if horizontal:\n", + " error_bar(\n", + " concatenated_df,\n", + " x=\"groups\",\n", + " y=\"values\",\n", + " ax=ax,\n", + " offset=0,\n", + " gap_width_percent=2,\n", + " method=\"sankey_error_bar\",\n", + " pos=[leftpos, leftpos + xMax],\n", + " horizontal=True,\n", + " )\n", + " else:\n", + " error_bar(\n", + " concatenated_df,\n", + " x=\"groups\",\n", + " y=\"values\",\n", + " ax=ax,\n", + " offset=0,\n", + " gap_width_percent=2,\n", + " method=\"sankey_error_bar\",\n", + " pos=[leftpos, leftpos + xMax],\n", + " )\n", "\n", " # Determine widths of individual strips, all widths are normalized to 1\n", " ns_l = defaultdict()\n", @@ -683,7 +705,11 @@ " rightWidths_norm[right_label][\"bottom\"] += ns_r_norm[left_label][\n", " right_label\n", " ]\n", - " ax.fill_between(\n", + " if horizontal:\n", + " fill_method = ax.fill_betweenx\n", + " else:\n", + " fill_method = ax.fill_between\n", + " fill_method(\n", " np.linspace(\n", " leftpos + (bar_width * xMax * 0.5),\n", " leftpos + xMax - (bar_width * xMax * 0.5),\n", @@ -696,7 +722,6 @@ " edgecolor=\"none\",\n", " )\n", "\n", - "\n", "def sankeydiag(\n", " data: pd.DataFrame,\n", " xvar: str, # x column to be plotted.\n", @@ -715,6 +740,7 @@ " right_color: bool = False, # if True, each strip of the diagram will be colored according to the corresponding left labels\n", " align: str = \"center\", # the alignment of each sankey diagram, can be 'center' or 'left'\n", " alpha: float = 0.65, # the transparency of each strip\n", + " horizontal: bool = False, # if True, the horizontal format for the sankey diagram will be used\n", " **kwargs,\n", "):\n", " \"\"\"\n", @@ -842,6 +868,7 @@ " flow=flow,\n", " align=align,\n", " alpha=alpha,\n", + " horizontal=horizontal,\n", " )\n", " xpos += 1\n", " else:\n", @@ -865,6 +892,7 @@ " flow=False,\n", " align=\"edge\",\n", " alpha=alpha,\n", + " horizontal=horizontal,\n", " )\n", "\n", " # Now only draw vs xticks for two-column sankey diagram\n", @@ -877,12 +905,20 @@ " for left, right in zip(broadcasted_left, right_idx)\n", " ]\n", " )\n", - " ax.get_xaxis().set_ticks(np.arange(len(right_idx)))\n", - " ax.get_xaxis().set_ticklabels(sankey_ticks)\n", + " if horizontal:\n", + " ax.get_yaxis().set_ticks(np.arange(len(right_idx)))\n", + " ax.get_yaxis().set_ticklabels(sankey_ticks)\n", + " else:\n", + " ax.get_xaxis().set_ticks(np.arange(len(right_idx)))\n", + " ax.get_xaxis().set_ticklabels(sankey_ticks)\n", " else:\n", " sankey_ticks = [broadcasted_left[0], right_idx[0]]\n", - " ax.set_xticks([0, 1])\n", - " ax.set_xticklabels(sankey_ticks)\n", + " if horizontal:\n", + " ax.set_yticks([0, 1])\n", + " ax.set_yticklabels(sankey_ticks) \n", + " else:\n", + " ax.set_xticks([0, 1])\n", + " ax.set_xticklabels(sankey_ticks)\n", "\n", " return left_idx, right_idx\n", "\n", diff --git a/nbs/API/plotter.ipynb b/nbs/API/plotter.ipynb index b1dd1b1f..61b70a85 100644 --- a/nbs/API/plotter.ipynb +++ b/nbs/API/plotter.ipynb @@ -878,6 +878,7 @@ " temp_idx=temp_idx,\n", " palette=plot_palette_sankey,\n", " ax=rawdata_axes,\n", + " horizontal=True,\n", " **sankey_kwargs\n", " )\n", "\n",