Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add cmasher.combine_cmaps #122

Merged
merged 1 commit into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions cmasher/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import cmasher as cmr
from cmasher import cm as cmrcm
from cmasher.utils import (
combine_cmaps,
create_cmap_mod,
create_cmap_overview,
get_bibtex,
Expand Down Expand Up @@ -65,6 +66,124 @@ def _MPL38_colormap_eq(cmap, other) -> bool:


# %% PYTEST CLASSES AND FUNCTIONS
# Pytest class for combine_cmaps
class Test_combine_cmaps:
# Test if multiple Colormaps or colormap names can be combined
@pytest.mark.parametrize(
"cmaps, nodes",
[
(["Blues", "Oranges", "Greens"], [0.25, 0.75]),
(["Blues", "Oranges", "Greens"], np.array([0.25, 0.75])),
(
[
mpl.colormaps["Blues"],
mpl.colormaps["Oranges"],
mpl.colormaps["Greens"],
],
[0.25, 0.75],
),
],
)
def test_combine_cmaps(self, cmaps, nodes):
combined_cmap = combine_cmaps(*cmaps, nodes=nodes, n_rgb_levels=256)
blues_cmap = mpl.colormaps["Blues"]
oranges_cmap = mpl.colormaps["Oranges"]
greens_cmap = mpl.colormaps["Greens"]

assert np.allclose(combined_cmap(0.0), blues_cmap(0))
assert np.allclose(combined_cmap(0.25), oranges_cmap(0))
assert np.allclose(combined_cmap(0.75), greens_cmap(0))
assert np.allclose(combined_cmap(1.0), greens_cmap(255))

assert combined_cmap.N == 256

# Test combine cmaps with default nodes
def test_default_nodes(self):
combined_cmap = combine_cmaps("Blues", "Oranges", n_rgb_levels=256)

blues_cmap = mpl.colormaps["Blues"]
oranges_cmap = mpl.colormaps["Oranges"]

assert np.allclose(combined_cmap(0.0), blues_cmap(0))
assert np.allclose(combined_cmap(0.5), oranges_cmap(0))
assert np.allclose(combined_cmap(1.0), oranges_cmap(255))

# Test if combining less than 2 colormaps triggers an error
@pytest.mark.parametrize(
"cmaps",
[
pytest.param([], id="no_cmap"),
pytest.param(["Blues"], id="single_cmap"),
pytest.param(["fake_name"], id="fake_cmap_name"),
],
)
def test_not_enough_cmaps(self, cmaps):
with pytest.raises(
ValueError, match="Expected at least two colormaps to combine."
):
combine_cmaps(*cmaps)

# Test if invalid colormap name raise an error
def test_invalid_cmap_name(self):
with pytest.raises(
KeyError,
match="'fake_cmap' is not a known colormap name",
):
combine_cmaps("fake_cmap", "Blues")

# Test if invalid colormap types raise an error
@pytest.mark.parametrize(
"invalid_cmap",
[0, 0.0, [], ()],
)
def test_invalid_cmap_types(self, invalid_cmap):
with pytest.raises(
TypeError,
match=f"Unsupported colormap type: {type(invalid_cmap)}.",
):
combine_cmaps("Blues", invalid_cmap)

# Test if invalid nodes types raise an error
def test_invalid_nodes_types(self):
invalid_nodes = "0.5"
with pytest.raises(
TypeError,
match=f"Unsupported nodes type: {type(invalid_nodes)}, expect list of float.",
):
combine_cmaps("Blues", "Greens", nodes=invalid_nodes)

# Test if mismatch cmaps and nodes length raise an error
@pytest.mark.parametrize(
"cmaps, nodes",
[
(["Blues", "Oranges", "Greens"], [0.5]),
(["Reds", "Blues"], [0.2, 0.8]),
],
)
def test_cmaps_nodes_length_mismatch(self, cmaps, nodes):
with pytest.raises(
ValueError,
match=("Number of nodes should be one less than the number of colormaps."),
):
combine_cmaps(*cmaps, nodes=nodes)

# Test if invalid nodes raise an error
@pytest.mark.parametrize(
"cmaps, nodes",
[
(["Blues", "Oranges", "Greens"], [-1, 0.75]),
(["Blues", "Oranges", "Greens"], [0.25, 2]),
(["Blues", "Oranges", "Greens"], [0.75, 0.25]),
],
)
def test_invalid_nodes(self, cmaps, nodes):
with pytest.raises(
ValueError,
match="Nodes should only contain increasing values between 0.0 and 1.0.",
):
combine_cmaps(*cmaps, nodes=nodes)


# Pytest class for create_cmap_mod
class Test_create_cmap_mod:
# Test if a standalone module of rainforest can be created
Expand Down
112 changes: 111 additions & 1 deletion cmasher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@

# Package imports
from colorspacious import cspace_converter
from matplotlib.colors import Colormap, ListedColormap as LC, to_hex, to_rgb
from matplotlib.colors import (
Colormap,
LinearSegmentedColormap,
ListedColormap as LC,
to_hex,
to_rgb,
)

# CMasher imports
from cmasher import cm as cmrcm
Expand All @@ -40,6 +46,7 @@

# All declaration
__all__ = [
"combine_cmaps",
"create_cmap_mod",
"create_cmap_overview",
"get_bibtex",
Expand Down Expand Up @@ -233,6 +240,109 @@ def _get_cmap_perceptual_rank(


# %% FUNCTIONS
# This function combines multiple colormaps at given nodes
def combine_cmaps(
*cmaps: Union[Colormap, str],
nodes: Optional[Union[list[float], np.ndarray]] = None,
n_rgb_levels: int = 256,
combined_cmap_name: str = "combined_cmap",
) -> LinearSegmentedColormap:
"""Create a composite matplotlib colormap by combining multiple colormaps.

Parameters
----------
*cmaps: Colormap or colormap name (str) to be combined.
nodes: list or numpy array of nodes (float). Defaults: equal divisions.
The blending points between colormaps, in the range [0, 1].
n_rgb_levels: int. Defaults: 256.
Number of RGB levels for each colormap segment.
combined_cmap_name: str. Defaults: "combined_cmap".
name of the combined Colormap.

Returns
-------
Colormap: The composite colormap.

Raises
------
TypeError: If the list contains mixed datatypes or invalid
colormap names.
ValueError: If the cmaps contain only one single colormap,
or if the number of nodes is not one less than the number
of colormaps, or if the nodes do not contain incrementing values
between 0.0 and 1.0.

Note
----
The colormaps are combined from low value to high value end.

References
----------
- https://stackoverflow.com/questions/31051488/combining-two-matplotlib-colormaps/31052741#31052741

Examples
--------
Using predefined colormap names::
>>> custom_cmap_1 = combine_cmaps(
["ocean", "prism", "coolwarm"], nodes=[0.2, 0.75]
)

Using Colormap objects::
>>> cmap_0 = plt.get_cmap("Blues")
>>> cmap_1 = plt.get_cmap("Oranges")
>>> cmap_2 = plt.get_cmap("Greens")
>>> custom_cmap_2 = combine_cmaps([cmap_0, cmap_1, cmap_2])

"""
# Check colormap datatype and convert to list[Colormap]
if len(cmaps) <= 1:
raise ValueError("Expected at least two colormaps to combine.")
for cm in cmaps:
if not isinstance(cm, (Colormap, str)):
raise TypeError(f"Unsupported colormap type: {type(cm)}.")
_cmaps: list[Colormap] = [
cm if isinstance(cm, Colormap) else mpl.colormaps[cm] for cm in cmaps
]

# Generate default nodes for equal separation
if nodes is None:
nodes_arr = np.linspace(0, 1, len(_cmaps) + 1)
elif isinstance(nodes, (list, np.ndarray)):
nodes_arr = np.concatenate([[0.0], nodes, [1.0]])
else:
raise TypeError(f"Unsupported nodes type: {type(nodes)}, expect list of float.")

# Check nodes length
if len(nodes_arr) != len(_cmaps) + 1:
raise ValueError(
"Number of nodes should be one less than the number of colormaps."
)

# Check node values
if any((nodes_arr < 0) | (nodes_arr > 1)) or any(np.diff(nodes_arr) <= 0):
raise ValueError(
"Nodes should only contain increasing values between 0.0 and 1.0."
)

# Generate composite colormap
combined_cmap_segments = []

for i, cmap in enumerate(_cmaps):
start_position = nodes_arr[i]
end_position = nodes_arr[i + 1]

# Calculate the length of the segment
segment_length = int(n_rgb_levels * (end_position - start_position))

# Append the segment to the combined colormap segments
combined_cmap_segments.append(cmap(np.linspace(0, 1, segment_length)))

# Combine the segments (from bottom to top)
return LinearSegmentedColormap.from_list(
combined_cmap_name, np.vstack(combined_cmap_segments)
)


# This function creates a standalone module of a CMasher colormap
def create_cmap_mod(
cmap: str, *, save_dir: str = ".", _copy_name: Optional[str] = None
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/user/images/combine_cmaps_equal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions docs/source/user/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,29 @@ For that reason, below is an overview of all colormaps in *CMasher* (and the rev
Application overview plot of *CMasher*'s colormaps.


.. _combine_colormaps:

Combine colormaps
-----------------
*CMasher* offers a utility function :func:`~cmasher.combine_cmaps`, which enables the combination of multiple colormaps at specified ``nodes`` (where a node denotes the point separating adjacent colormaps, within the interval [0, 1]). You can directly pass several colormaps using the function like so :pycode:`combine_cmaps("cmr.rainforest", "cmr.torch_r")`. By default, each sub-colormap occupies an equal portion of the final colormap.

.. figure:: images/combine_cmaps_equal.png
:alt: Combine two colormaps with default equal separation.
:width: 100%
:align: center

Combine two colormaps with default equal separation.

Alternatively, you may want to specify ``nodes`` explicitly. For example :pycode:`cmr.combine_cmaps("cmr.rainforest", "cmr.torch_r", nodes=[0.75])` would allocate the starting 75% of the final colormap to "cmr.rainforest"and the remaining 25% to "cmr.torch_r".

.. figure:: images/combine_cmaps_0.75_0.25.png
:alt: Combine two colormaps with a 75%/25% separation.
:width: 100%
:align: center

Combine two colormaps with a 75%/25% separation.


Command-line interface (CLI)
----------------------------
Although *CMasher* is written in Python, some of its utility functions do not require the interpreter in order to be used properly.
Expand Down Expand Up @@ -216,6 +239,7 @@ The script and image below show an example of this::

Hexbin plot using a colormap legend entry for the :ref:`rainforest` colormap.


.. _sub_colormaps:

Sub-colormaps
Expand Down
Loading