diff --git a/enum_tools/demo.py b/enum_tools/demo.py index fd35d0a..69d956e 100644 --- a/enum_tools/demo.py +++ b/enum_tools/demo.py @@ -71,7 +71,7 @@ class StatusFlags(IntFlag): Stopped = 2 # doc: The system has stopped. Error = 4 # doc: An error has occurred. - def has_errored(self) -> bool: + def has_errored(self) -> bool: # pragma: no cover """ Returns whether the operation has errored. """ diff --git a/enum_tools/utils.py b/enum_tools/utils.py index e31f3b5..da71065 100644 --- a/enum_tools/utils.py +++ b/enum_tools/utils.py @@ -87,9 +87,16 @@ def get_base_object(enum: Type[HasMRO]) -> Type: If the members are of indeterminate type then the :class:`object` class is returned. :param enum: + + :rtype: + + :raises TypeError: If ``enum`` is not an Enum. """ - mro = inspect.getmro(enum) + try: + mro = inspect.getmro(enum) + except AttributeError: + raise TypeError("not an Enum") if Flag in mro: mro = mro[:mro.index(Flag)] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..c701749 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,83 @@ +# stdlib +import enum +import http + +# 3rd party +import pytest + +# this package +from enum_tools import StrEnum +from enum_tools.utils import get_base_object, is_enum, is_enum_member, is_flag + + +@pytest.mark.parametrize( + "obj, result", + [ + (enum.Enum, True), + (http.HTTPStatus, True), + (http.HTTPStatus.NOT_ACCEPTABLE, False), + (123, False), + ("abc", False), + ] + ) +def test_is_enum(obj, result): + assert is_enum(obj) == result + + +@pytest.mark.parametrize( + "obj, result", + [ + (enum.Enum, False), + (http.HTTPStatus, False), + (http.HTTPStatus.NOT_ACCEPTABLE, True), + (123, False), + ("abc", False), + ] + ) +def test_is_enum_member(obj, result): + assert is_enum_member(obj) == result + + +class Colours(enum.Flag): + RED = 1 + BLUE = 2 + + +PURPLE = Colours.RED | Colours.BLUE + + +@pytest.mark.parametrize( + "obj, result", + [ + (enum.Enum, False), + (http.HTTPStatus, False), + (http.HTTPStatus.NOT_ACCEPTABLE, False), + (123, False), + ("abc", False), + (Colours, True), + (Colours.RED, False), + (PURPLE, False), + ] + ) +def test_is_flag(obj, result): + assert is_flag(obj) == result + + +def test_get_base_object(): + # TODO: report issue to mypy + assert get_base_object(enum.Enum) is object # type: ignore[arg-type] + assert get_base_object(Colours) is object # type: ignore[arg-type] + assert get_base_object(enum.IntFlag) is int # type: ignore[arg-type] + assert get_base_object(StrEnum) is str # type: ignore[arg-type] + + with pytest.raises(TypeError, match="not an Enum"): + get_base_object("abc") # type: ignore[arg-type] + + with pytest.raises(TypeError, match="not an Enum"): + get_base_object(123) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="not an Enum"): + get_base_object(str) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="not an Enum"): + get_base_object(int) # type: ignore[arg-type]