diff --git a/cmasher/tests/test_utils.py b/cmasher/tests/test_utils.py index 4024dcaf..1c0769d3 100644 --- a/cmasher/tests/test_utils.py +++ b/cmasher/tests/test_utils.py @@ -16,6 +16,7 @@ import cmasher as cmr from cmasher import cm as cmrcm from cmasher.utils import ( + combine_cmaps, create_cmap_mod, create_cmap_overview, get_bibtex, @@ -65,6 +66,124 @@ def _MPL38_colormap_eq(cmap, other) -> bool: # %% PYTEST CLASSES AND FUNCTIONS +# Pytest class for combine_cmaps +class Test_combine_cmaps: + # Test if multiple Colormaps or colormap names can be combined + @pytest.mark.parametrize( + "cmaps, nodes", + [ + (["Blues", "Oranges", "Greens"], [0.25, 0.75]), + (["Blues", "Oranges", "Greens"], np.array([0.25, 0.75])), + ( + [ + mpl.colormaps["Blues"], + mpl.colormaps["Oranges"], + mpl.colormaps["Greens"], + ], + [0.25, 0.75], + ), + ], + ) + def test_combine_cmaps(self, cmaps, nodes): + combined_cmap = combine_cmaps(*cmaps, nodes=nodes, n_rgb_levels=256) + blues_cmap = mpl.colormaps["Blues"] + oranges_cmap = mpl.colormaps["Oranges"] + greens_cmap = mpl.colormaps["Greens"] + + assert np.allclose(combined_cmap(0.0), blues_cmap(0)) + assert np.allclose(combined_cmap(0.25), oranges_cmap(0)) + assert np.allclose(combined_cmap(0.75), greens_cmap(0)) + assert np.allclose(combined_cmap(1.0), greens_cmap(255)) + + assert combined_cmap.N == 256 + + # Test combine cmaps with default nodes + def test_default_nodes(self): + combined_cmap = combine_cmaps("Blues", "Oranges", n_rgb_levels=256) + + blues_cmap = mpl.colormaps["Blues"] + oranges_cmap = mpl.colormaps["Oranges"] + + assert np.allclose(combined_cmap(0.0), blues_cmap(0)) + assert np.allclose(combined_cmap(0.5), oranges_cmap(0)) + assert np.allclose(combined_cmap(1.0), oranges_cmap(255)) + + # Test if combining less than 2 colormaps triggers an error + @pytest.mark.parametrize( + "cmaps", + [ + pytest.param([], id="no_cmap"), + pytest.param(["Blues"], id="single_cmap"), + pytest.param(["fake_name"], id="fake_cmap_name"), + ], + ) + def test_not_enough_cmaps(self, cmaps): + with pytest.raises( + ValueError, match="Expected at least two colormaps to combine." + ): + combine_cmaps(*cmaps) + + # Test if invalid colormap name raise an error + def test_invalid_cmap_name(self): + with pytest.raises( + KeyError, + match="'fake_cmap' is not a known colormap name", + ): + combine_cmaps("fake_cmap", "Blues") + + # Test if invalid colormap types raise an error + @pytest.mark.parametrize( + "invalid_cmap", + [0, 0.0, [], ()], + ) + def test_invalid_cmap_types(self, invalid_cmap): + with pytest.raises( + TypeError, + match=f"Unsupported colormap type: {type(invalid_cmap)}.", + ): + combine_cmaps("Blues", invalid_cmap) + + # Test if invalid nodes types raise an error + def test_invalid_nodes_types(self): + invalid_nodes = "0.5" + with pytest.raises( + TypeError, + match=f"Unsupported nodes type: {type(invalid_nodes)}, expect list of float.", + ): + combine_cmaps("Blues", "Greens", nodes=invalid_nodes) + + # Test if mismatch cmaps and nodes length raise an error + @pytest.mark.parametrize( + "cmaps, nodes", + [ + (["Blues", "Oranges", "Greens"], [0.5]), + (["Reds", "Blues"], [0.2, 0.8]), + ], + ) + def test_cmaps_nodes_length_mismatch(self, cmaps, nodes): + with pytest.raises( + ValueError, + match=("Number of nodes should be one less than the number of colormaps."), + ): + combine_cmaps(*cmaps, nodes=nodes) + + # Test if invalid nodes raise an error + @pytest.mark.parametrize( + "cmaps, nodes", + [ + (["Blues", "Oranges", "Greens"], [-1, 0.75]), + (["Blues", "Oranges", "Greens"], [0.25, 2]), + (["Blues", "Oranges", "Greens"], [0.75, 0.25]), + ], + ) + def test_invalid_nodes(self, cmaps, nodes): + with pytest.raises( + ValueError, + match="Nodes should only contain increasing values between 0.0 and 1.0.", + ): + combine_cmaps(*cmaps, nodes=nodes) + + # Pytest class for create_cmap_mod class Test_create_cmap_mod: # Test if a standalone module of rainforest can be created diff --git a/cmasher/utils.py b/cmasher/utils.py index b8f2e0c6..44d8eaf2 100644 --- a/cmasher/utils.py +++ b/cmasher/utils.py @@ -21,7 +21,13 @@ # Package imports from colorspacious import cspace_converter -from matplotlib.colors import Colormap, ListedColormap as LC, to_hex, to_rgb +from matplotlib.colors import ( + Colormap, + LinearSegmentedColormap, + ListedColormap as LC, + to_hex, + to_rgb, +) # CMasher imports from cmasher import cm as cmrcm @@ -40,6 +46,7 @@ # All declaration __all__ = [ + "combine_cmaps", "create_cmap_mod", "create_cmap_overview", "get_bibtex", @@ -233,6 +240,109 @@ def _get_cmap_perceptual_rank( # %% FUNCTIONS +# This function combines multiple colormaps at given nodes +def combine_cmaps( + *cmaps: Union[Colormap, str], + nodes: Optional[Union[list[float], np.ndarray]] = None, + n_rgb_levels: int = 256, + combined_cmap_name: str = "combined_cmap", +) -> LinearSegmentedColormap: + """Create a composite matplotlib colormap by combining multiple colormaps. + + Parameters + ---------- + *cmaps: Colormap or colormap name (str) to be combined. + nodes: list or numpy array of nodes (float). Defaults: equal divisions. + The blending points between colormaps, in the range [0, 1]. + n_rgb_levels: int. Defaults: 256. + Number of RGB levels for each colormap segment. + combined_cmap_name: str. Defaults: "combined_cmap". + name of the combined Colormap. + + Returns + ------- + Colormap: The composite colormap. + + Raises + ------ + TypeError: If the list contains mixed datatypes or invalid + colormap names. + ValueError: If the cmaps contain only one single colormap, + or if the number of nodes is not one less than the number + of colormaps, or if the nodes do not contain incrementing values + between 0.0 and 1.0. + + Note + ---- + The colormaps are combined from low value to high value end. + + References + ---------- + - https://stackoverflow.com/questions/31051488/combining-two-matplotlib-colormaps/31052741#31052741 + + Examples + -------- + Using predefined colormap names:: + >>> custom_cmap_1 = combine_cmaps( + ["ocean", "prism", "coolwarm"], nodes=[0.2, 0.75] + ) + + Using Colormap objects:: + >>> cmap_0 = plt.get_cmap("Blues") + >>> cmap_1 = plt.get_cmap("Oranges") + >>> cmap_2 = plt.get_cmap("Greens") + >>> custom_cmap_2 = combine_cmaps([cmap_0, cmap_1, cmap_2]) + + """ + # Check colormap datatype and convert to list[Colormap] + if len(cmaps) <= 1: + raise ValueError("Expected at least two colormaps to combine.") + for cm in cmaps: + if not isinstance(cm, (Colormap, str)): + raise TypeError(f"Unsupported colormap type: {type(cm)}.") + _cmaps: list[Colormap] = [ + cm if isinstance(cm, Colormap) else mpl.colormaps[cm] for cm in cmaps + ] + + # Generate default nodes for equal separation + if nodes is None: + nodes_arr = np.linspace(0, 1, len(_cmaps) + 1) + elif isinstance(nodes, (list, np.ndarray)): + nodes_arr = np.concatenate([[0.0], nodes, [1.0]]) + else: + raise TypeError(f"Unsupported nodes type: {type(nodes)}, expect list of float.") + + # Check nodes length + if len(nodes_arr) != len(_cmaps) + 1: + raise ValueError( + "Number of nodes should be one less than the number of colormaps." + ) + + # Check node values + if any((nodes_arr < 0) | (nodes_arr > 1)) or any(np.diff(nodes_arr) <= 0): + raise ValueError( + "Nodes should only contain increasing values between 0.0 and 1.0." + ) + + # Generate composite colormap + combined_cmap_segments = [] + + for i, cmap in enumerate(_cmaps): + start_position = nodes_arr[i] + end_position = nodes_arr[i + 1] + + # Calculate the length of the segment + segment_length = int(n_rgb_levels * (end_position - start_position)) + + # Append the segment to the combined colormap segments + combined_cmap_segments.append(cmap(np.linspace(0, 1, segment_length))) + + # Combine the segments (from bottom to top) + return LinearSegmentedColormap.from_list( + combined_cmap_name, np.vstack(combined_cmap_segments) + ) + + # This function creates a standalone module of a CMasher colormap def create_cmap_mod( cmap: str, *, save_dir: str = ".", _copy_name: Optional[str] = None diff --git a/docs/source/user/images/combine_cmaps_0.75_0.25.png b/docs/source/user/images/combine_cmaps_0.75_0.25.png new file mode 100644 index 00000000..e91d73d1 Binary files /dev/null and b/docs/source/user/images/combine_cmaps_0.75_0.25.png differ diff --git a/docs/source/user/images/combine_cmaps_equal.png b/docs/source/user/images/combine_cmaps_equal.png new file mode 100644 index 00000000..05147eb1 Binary files /dev/null and b/docs/source/user/images/combine_cmaps_equal.png differ diff --git a/docs/source/user/usage.rst b/docs/source/user/usage.rst index 00f29442..38c87595 100644 --- a/docs/source/user/usage.rst +++ b/docs/source/user/usage.rst @@ -111,6 +111,29 @@ For that reason, below is an overview of all colormaps in *CMasher* (and the rev Application overview plot of *CMasher*'s colormaps. +.. _combine_colormaps: + +Combine colormaps +----------------- +*CMasher* offers a utility function :func:`~cmasher.combine_cmaps`, which enables the combination of multiple colormaps at specified ``nodes`` (where a node denotes the point separating adjacent colormaps, within the interval [0, 1]). You can directly pass several colormaps using the function like so :pycode:`combine_cmaps("cmr.rainforest", "cmr.torch_r")`. By default, each sub-colormap occupies an equal portion of the final colormap. + +.. figure:: images/combine_cmaps_equal.png + :alt: Combine two colormaps with default equal separation. + :width: 100% + :align: center + + Combine two colormaps with default equal separation. + +Alternatively, you may want to specify ``nodes`` explicitly. For example :pycode:`cmr.combine_cmaps("cmr.rainforest", "cmr.torch_r", nodes=[0.75])` would allocate the starting 75% of the final colormap to "cmr.rainforest"and the remaining 25% to "cmr.torch_r". + +.. figure:: images/combine_cmaps_0.75_0.25.png + :alt: Combine two colormaps with a 75%/25% separation. + :width: 100% + :align: center + + Combine two colormaps with a 75%/25% separation. + + Command-line interface (CLI) ---------------------------- Although *CMasher* is written in Python, some of its utility functions do not require the interpreter in order to be used properly. @@ -216,6 +239,7 @@ The script and image below show an example of this:: Hexbin plot using a colormap legend entry for the :ref:`rainforest` colormap. + .. _sub_colormaps: Sub-colormaps