Skip to content

Commit

Permalink
feat(python): Improve Polars Enum dtype init from standard Python enums (#19997)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 28, 2024
1 parent 99c7f4d commit b83d847
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 21 deletions.
32 changes: 28 additions & 4 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def __hash__(self) -> int:

class Enum(DataType):
"""
A fixed set categorical encoding of a set of strings.
A fixed categorical encoding of a unique set of strings.
.. warning::
This functionality is considered **unstable**.
Expand All @@ -592,8 +592,22 @@ class Enum(DataType):
Parameters
----------
categories
The categories in the dataset. Categories must be strings.
"""
The categories in the dataset; must be a unique set of strings, or an
existing Python string-valued enum.
Examples
--------
Explicitly define enumeration categories:
>>> pl.Enum(["north", "south", "east", "west"])
Enum(categories=['north', 'south', 'east', 'west'])
Initialise from an existing Python enumeration:
>>> from http import HTTPMethod
>>> pl.Enum(HTTPMethod)
Enum(categories=['CONNECT', 'DELETE', 'GET', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT', 'TRACE'])
""" # noqa: W505

categories: Series

Expand All @@ -608,7 +622,17 @@ def __init__(self, categories: Series | Iterable[str] | type[enum.Enum]) -> None
)

if isclass(categories) and issubclass(categories, enum.Enum):
categories = pl.Series(values=categories.__members__.values())
for enum_subclass in (enum.IntFlag, enum.Flag, enum.IntEnum):
if issubclass(categories, enum_subclass):
enum_type_name = enum_subclass.__name__
msg = f"Enum categories must be strings; Python `enum.{enum_type_name}` values are integers"
raise TypeError(msg)

enum_values = [
(v if isinstance(v, str) else v.value)
for v in categories.__members__.values()
]
categories = pl.Series(values=enum_values)
elif not isinstance(categories, pl.Series):
categories = pl.Series(values=categories)

Expand Down
73 changes: 56 additions & 17 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import operator
import re
import sys
from datetime import date
from textwrap import dedent
from typing import Any, Callable
Expand Down Expand Up @@ -42,32 +43,70 @@ def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None:
assert_series_equal(dtype.categories, expected)


def test_enum_init_python_enum_19724() -> None:
class PythonEnum(str, enum.Enum):
CAT1 = "A"
CAT2 = "B"
CAT3 = "C"
def test_enum_init_from_python() -> None:
# standard string enum
class Color1(str, enum.Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"

result = pl.Enum(PythonEnum)
assert result == pl.Enum(["A", "B", "C"])
dtype = pl.Enum(Color1)
assert dtype == pl.Enum(["red", "green", "blue"])

# standard generic enum
class Color2(enum.Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"

def test_enum_init_python_enum_ints_19724() -> None:
class PythonEnum(int, enum.Enum):
CAT1 = 1
CAT2 = 2
CAT3 = 3
dtype = pl.Enum(Color2)
assert dtype == pl.Enum(["red", "green", "blue"])

with pytest.raises(TypeError, match="Enum categories must be strings"):
pl.Enum(PythonEnum)
# specialised string enum
if sys.version_info >= (3, 11):

class Color3(enum.Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"

dtype = pl.Enum(Color3)
assert dtype == pl.Enum(["red", "green", "blue"])


def test_enum_init_from_python_invalid() -> None:
    """Check that `pl.Enum` raises `TypeError` for integer-valued Python enums."""

    # enum with an `int` mixin: member values are integers, not strings
    class Color(int, enum.Enum):
        RED = 1
        GREEN = 2
        BLUE = 3

    generic_msg = "Enum categories must be strings"
    with pytest.raises(TypeError, match=generic_msg):
        pl.Enum(Color)

    # flag/int enums: the error message is expected to name the offending base
    integer_bases = (enum.Flag, enum.IntFlag, enum.IntEnum)
    for EnumBase in integer_bases:

        class Color(EnumBase):  # type: ignore[no-redef,misc,valid-type]
            RED = enum.auto()
            GREEN = enum.auto()
            BLUE = enum.auto()

        expected_msg = f"Enum categories must be strings; Python `enum.{EnumBase.__name__}` values are integers"
        with pytest.raises(TypeError, match=expected_msg):
            pl.Enum(Color)


def test_enum_non_existent() -> None:
    """Check that building an Enum series with a value outside its categories raises."""
    # `pytest.raises(match=...)` treats the pattern as a regex and the expected
    # message contains `[` / `]`, so escape the literal once with `re.escape`
    # rather than hand-escaping it. Exactly one `match=` argument is passed.
    expected_msg = re.escape(
        "conversion from `str` to `enum` failed in column '' for 1 out of 4 values: [\"c\"]"
    )
    with pytest.raises(InvalidOperationError, match=expected_msg):
        pl.Series([None, "a", "b", "c"], dtype=pl.Enum(categories=["a", "b"]))

Expand Down

0 comments on commit b83d847

Please sign in to comment.