Skip to content

Commit

Permalink
fix(polars): support polars Enum type
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 4, 2024
1 parent ac79604 commit 1546b87
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
23 changes: 23 additions & 0 deletions ibis/backends/polars/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import polars as pl
import polars.testing
import pytest

import ibis
Expand Down Expand Up @@ -37,3 +39,24 @@ def test_array_flatten(con):
{"id": data["id"], "flat": [row[0] for row in data["happy"]]}
)
tm.assert_frame_equal(result.to_pandas(), expected)


def test_memtable_polars_types(con):
    """Memtables built from polars-specific dtypes work in later expressions.

    A frame with String, Categorical and Enum columns is wrapped in an ibis
    memtable; adding the three columns through ibis must produce the same
    series as adding them directly in polars.

    The ``con`` fixture is requested but never referenced in the body.
    """
    columns = {
        "x": ["a", "b", "a"],
        "y": ["c", "d", "c"],
        "z": ["e", "f", "e"],
    }
    dtypes = {
        "x": pl.String,
        "y": pl.Categorical,
        "z": pl.Enum(["e", "f"]),
    }
    df = pl.DataFrame(columns, schema=dtypes)

    # Route the frame through ibis and materialize the sum back into polars.
    table = ibis.memtable(df)
    summed = table.x + table.y + table.z
    actual = summed.name("test").to_polars()

    # Reference result computed natively in polars on the same frame.
    expected = (df["x"] + df["y"] + df["z"]).rename("test")

    pl.testing.assert_series_equal(actual, expected)
2 changes: 1 addition & 1 deletion ibis/formats/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def to_ibis(cls, typ: pl.DataType, nullable=True) -> dt.DataType:
"""Convert a polars type to an ibis type."""

base_type = typ.base_type()
if base_type is pl.Categorical:
if base_type in (pl.Categorical, pl.Enum):
return dt.String(nullable=nullable)
elif base_type is pl.Decimal:
return dt.Decimal(
Expand Down
3 changes: 2 additions & 1 deletion ibis/formats/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def test_decimal():
)


def test_categorical():
def test_enum_categorical():
    """Polars Categorical and Enum types both convert to the ibis string type."""
    categorical_as_ibis = PolarsType.to_ibis(pl.Categorical())
    assert categorical_as_ibis == dt.string

    enum_as_ibis = PolarsType.to_ibis(pl.Enum(["a", "b"]))
    assert enum_as_ibis == dt.string


def test_interval_unsupported_unit():
Expand Down

0 comments on commit 1546b87

Please sign in to comment.