Skip to content

Commit

Permalink
Option for gradient coloring and alpha in edges (#36)
Browse files Browse the repository at this point in the history
* Include fransua's gradient coloring and alpha edges code
  • Loading branch information
pintergreg authored Nov 6, 2023
1 parent 11a2e8b commit bd0a82f
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 8 deletions.
184 changes: 184 additions & 0 deletions pysankey/gradient_test.ipynb

Large diffs are not rendered by default.

85 changes: 77 additions & 8 deletions pysankey/sankey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -72,6 +73,8 @@ def sankey(
closePlot: bool = False,
figSize: Optional[Tuple[int, int]] = None,
ax: Optional[Any] = None,
color_gradient: bool = False,
alphaDict: Optional[Dict[Union[str, Tuple[str, str]], float]] = None,
) -> Any:
"""
Make Sankey Diagram showing flow from left-->right
Expand Down Expand Up @@ -128,6 +131,20 @@ def sankey(
rightWidths, topEdge = _get_positions_and_total_widths(
data_frame, rightLabels, "right"
)
# If no alphaDict given, make one
if alphaDict is None:
alphaDict = {}
for _, label in enumerate(all_labels):
alphaDict[label] = 0.65
else:
missing = [label for label in all_labels if label not in alphaDict.keys()]
if missing:
msg = (
"The alphaDict parameter is missing values for the following labels : "
)
msg += ", ".join(missing)
raise ValueError(msg)
LOGGER.debug("The alphadict value are : %s", alphaDict)
# Total vertical extent of diagram
xMax = topEdge / aspect
draw_vertical_bars(
Expand All @@ -152,6 +169,8 @@ def sankey(
rightLabels,
rightWidths,
xMax,
alphaDict,
color_gradient,
)
if figSize is not None:
plt.gcf().set_size_inches(figSize)
Expand Down Expand Up @@ -353,7 +372,10 @@ def _create_dataframe(

def plot_strips(
ax: Any,
colorDict: Union[Dict[str, Tuple[float, float, float]], Dict[str, str]],
colorDict: Union[
Dict[Union[str, Tuple[str, str]], Tuple[float, float, float]],
Dict[Union[str, Tuple[str, str]], str],
],
dataFrame: DataFrame,
leftLabels: ndarray,
leftWidths: Dict,
Expand All @@ -363,6 +385,8 @@ def plot_strips(
rightLabels: ndarray,
rightWidths: Dict,
xMax: float64,
alphaDict: Dict[Union[str, Tuple[str, str]], float],
color_gradient: bool = False,
) -> None:
# Plot strips
for leftLabel in leftLabels:
Expand Down Expand Up @@ -398,13 +422,58 @@ def plot_strips(
# right place
leftWidths[leftLabel]["bottom"] += ns_l[leftLabel][rightLabel]
rightWidths[rightLabel]["bottom"] += ns_r[leftLabel][rightLabel]
ax.fill_between(
np.linspace(0, xMax, len(ys_d)),
ys_d,
ys_u,
alpha=0.65,
color=colorDict[label_color],
)

if (leftLabel, rightLabel) in alphaDict:
alpha = alphaDict[leftLabel, rightLabel]
else:
alpha = alphaDict[label_color]
if color_gradient:
if (leftLabel, rightLabel) in colorDict:
cleft = cright = colorDict[leftLabel, rightLabel]
else:
cleft = colorDict[leftLabel]
cright = colorDict[rightLabel]

x = list(np.linspace(0, xMax, len(ys_d)))
(poly,) = ax.fill(
x + x[::-1] + [x[0]],
list(ys_d) + list(ys_u)[::-1] + [ys_d[0]],
facecolor="none",
)

# get the extent of the axes
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()

# create a dummy image
img_data = np.arange(xmin, xmax, (xmax - xmin) / 100.0)
img_data = img_data.reshape(img_data.size, 1).T

# plot and clip the image
im = ax.imshow(
img_data,
aspect="auto",
origin="lower",
cmap=mpl.colors.LinearSegmentedColormap.from_list(
"custom", [cleft, cright]
),
alpha=alpha,
extent=[xmin, xmax, ymin, ymax],
)

im.set_clip_path(poly)
else:
if (leftLabel, rightLabel) in colorDict:
color = colorDict[leftLabel, rightLabel]
else:
color = colorDict[label_color]
ax.fill_between(
np.linspace(0, xMax, len(ys_d)),
ys_d,
ys_u,
alpha=alpha,
color=color,
)
ax.axis("off")


Expand Down

0 comments on commit bd0a82f

Please sign in to comment.