Skip to content

Commit 530cb9e

Browse files
committed
TYP: improve type annotations for take_cmap_colors
1 parent 9bebc27 commit 530cb9e

File tree

1 file changed

+44
-18
lines changed

1 file changed

+44
-18
lines changed

src/cmasher/utils.py

+44-18
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from itertools import chain
1313
from pathlib import Path
1414
from textwrap import dedent
15-
from typing import TYPE_CHECKING, NewType
15+
from typing import TYPE_CHECKING, NewType, overload
1616

1717
import matplotlib as mpl
1818
import numpy as np
@@ -33,7 +33,7 @@
3333
import os
3434
import sys
3535
from collections.abc import Callable, Iterator
36-
from typing import Literal, Protocol, TypeAlias
36+
from typing import Literal, Protocol, TypeAlias, TypeVar
3737

3838
from matplotlib.artist import Artist
3939
from numpy.typing import NDArray
@@ -43,6 +43,9 @@
4343
else:
4444
from typing_extensions import Self
4545

46+
T = TypeVar("T", int, float)
47+
RGB: TypeAlias = tuple[T, T, T]
48+
4649
class SupportsDunderLT(Protocol):
4750
def __lt__(self, other: Self, /) -> bool: ...
4851

@@ -51,6 +54,7 @@ def __gt__(self, other: Self, /) -> bool: ...
5154

5255
SupportsOrdering: TypeAlias = SupportsDunderLT | SupportsDunderGT
5356

57+
5458
_HAS_VISCM = find_spec("viscm") is not None
5559

5660
# All declaration
@@ -78,12 +82,6 @@ def __gt__(self, other: Self, /) -> bool: ...
7882
Category = NewType("Category", str)
7983
Name = NewType("Name", str)
8084

81-
# Type aliases
82-
RED: TypeAlias = float
83-
GREEN: TypeAlias = float
84-
BLUE: TypeAlias = float
85-
RGB: TypeAlias = list[tuple[RED, GREEN, BLUE]]
86-
8785

8886
# %% HELPER FUNCTIONS
8987
# Define function for obtaining the sorting order for lightness ranking
@@ -1436,13 +1434,43 @@ def set_cmap_legend_entry(artist: Artist, label: str) -> None:
14361434

14371435

14381436
# Function to take N equally spaced colors from a colormap
1437+
@overload
1438+
def take_cmap_colors(
1439+
cmap: Colormap | Name,
1440+
N: int | None,
1441+
*,
1442+
cmap_range: tuple[float, float] = (0, 1),
1443+
return_fmt: Literal["float", "norm"] = "float",
1444+
) -> RGB[float]: ...
1445+
1446+
1447+
@overload
14391448
def take_cmap_colors(
14401449
cmap: Colormap | Name,
14411450
N: int | None,
14421451
*,
14431452
cmap_range: tuple[float, float] = (0, 1),
1444-
return_fmt: str = "float",
1445-
) -> RGB:
1453+
return_fmt: Literal["int", "8bit"],
1454+
) -> RGB[int]: ...
1455+
1456+
1457+
@overload
1458+
def take_cmap_colors(
1459+
cmap: Colormap | Name,
1460+
N: int | None,
1461+
*,
1462+
cmap_range: tuple[float, float] = (0, 1),
1463+
return_fmt: Literal["str", "hex"],
1464+
) -> list[str]: ...
1465+
1466+
1467+
def take_cmap_colors(
1468+
cmap: Colormap | Name,
1469+
N: int | None,
1470+
*,
1471+
cmap_range: tuple[float, float] = (0, 1),
1472+
return_fmt: Literal["float", "norm", "int", "8bit", "str", "hex"] = "float",
1473+
) -> RGB[float] | RGB[int] | list[str]:
14461474
"""
14471475
Takes `N` equally spaced colors from the provided colormap `cmap` and
14481476
returns them.
@@ -1514,9 +1542,6 @@ def take_cmap_colors(
15141542
that describe the same property, but have a different initial state.
15151543
15161544
"""
1517-
# Convert provided fmt to lowercase
1518-
return_fmt = return_fmt.lower()
1519-
15201545
# Obtain the colormap
15211546
if isinstance(cmap, str):
15221547
cmap = mpl.colormaps[cmap]
@@ -1544,12 +1569,13 @@ def take_cmap_colors(
15441569
colors = np.apply_along_axis(to_rgb, 1, colors) # type: ignore [call-overload]
15451570
if return_fmt in ("int", "8bit"):
15461571
colors = np.array(np.rint(colors * 255), dtype=int)
1547-
colors = list(map(tuple, colors))
1572+
return [(int(c[0]), int(c[1]), int(c[2])) for c in colors] # type: ignore [misc]
1573+
else:
1574+
return [(float(c[0]), float(c[1]), float(c[2])) for c in colors] # type: ignore [misc]
1575+
elif return_fmt in ("str", "hex"):
1576+
return [to_hex(x).upper() for x in colors]
15481577
else:
1549-
colors = [to_hex(x).upper() for x in colors]
1550-
1551-
# Return colors
1552-
return colors
1578+
raise ValueError(return_fmt)
15531579

15541580

15551581
# Function to view what a colormap looks like

0 commit comments

Comments
 (0)