Skip to content

Commit

Permalink
Add cat.to_enum
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Mar 12, 2024
1 parent 8c2eca9 commit bef1dd5
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 3 deletions.
14 changes: 14 additions & 0 deletions crates/polars-core/src/chunked_array/logical/categorical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ impl CategoricalChunked {
}
}

/// Convert to a fixed `Enum` that reuses the existing categories.
///
/// The array is first brought to its local representation via `to_local`, and
/// the result is rebuilt from that local array's physical values, rev-map and
/// ordering with the enum flag set.
pub fn convert_to_enum(&self) -> Self {
    let local = self.to_local();
    let physical = local.physical().clone();
    let rev_map = local.get_rev_map().clone();
    let is_enum = true;
    let ordering = local.get_ordering();
    // SAFETY: the physical values and the rev-map are both taken from the same
    // local array, so we create the physical directly from existing data.
    unsafe {
        CategoricalChunked::from_cats_and_rev_map_unchecked(physical, rev_map, is_enum, ordering)
    }
}

// Convert to fixed enum. Values not in categories are mapped to None.
pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self {
// Fast paths
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,8 @@ impl CategoricalNameSpace {
self.0
.map_private(CategoricalFunction::GetCategories.into())
}

/// Build an expression that applies `CategoricalFunction::ToEnum` to the
/// wrapped expression.
pub fn to_enum(self) -> Expr {
self.0.map_private(CategoricalFunction::ToEnum.into())
}
}
9 changes: 9 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ use crate::map;
/// Functions of the `cat` expression namespace.
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum CategoricalFunction {
/// Displayed as `cat.get_categories`; its output field has dtype `String`.
GetCategories,
/// Displayed as `cat.to_enum`; its output dtype is resolved by
/// `FieldsMapper::map_cat_to_enum_dtype`.
ToEnum,
}

impl CategoricalFunction {
/// Resolve the output field of this function using the given `FieldsMapper`.
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use CategoricalFunction::*;
match self {
// The categories are returned as a `String` column.
GetCategories => mapper.with_dtype(DataType::String),
// The dtype mapping is delegated to the mapper's cat-to-enum helper.
ToEnum => mapper.map_cat_to_enum_dtype(),
}
}
}
Expand All @@ -21,6 +23,7 @@ impl Display for CategoricalFunction {
use CategoricalFunction::*;
let s = match self {
GetCategories => "get_categories",
ToEnum => "to_enum",
};
write!(f, "cat.{s}")
}
Expand All @@ -31,6 +34,7 @@ impl From<CategoricalFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
use CategoricalFunction::*;
match func {
GetCategories => map!(get_categories),
ToEnum => map!(to_enum),
}
}
}
Expand All @@ -48,3 +52,8 @@ fn get_categories(s: &Series) -> PolarsResult<Series> {
let arr = rev_map.get_categories().clone().boxed();
Series::try_from((ca.name(), arr))
}

/// Convert a categorical `Series` into an Enum `Series` via
/// `CategoricalChunked::convert_to_enum`.
///
/// # Errors
/// Returns the error produced by `s.categorical()`.
fn to_enum(s: &Series) -> PolarsResult<Series> {
    // `convert_to_enum` switches to the local representation itself, so an
    // additional `to_local()` here would only perform the conversion twice.
    let ca = s.categorical()?;
    Ok(ca.convert_to_enum().into_series())
}
16 changes: 16 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,22 @@ impl<'a> FieldsMapper<'a> {
})
}

/// Map a `Categorical` dtype to an `Enum` dtype.
///
/// A global rev-map is replaced by a local one built from its categories; any
/// other rev-map is reused as is. Non-categorical dtypes are returned unchanged.
pub fn map_cat_to_enum_dtype(&self) -> PolarsResult<Field> {
    self.map_dtype(|dtype| {
        let DataType::Categorical(rev_map, ordering) = dtype else {
            return dtype.clone();
        };
        let enum_rev_map = rev_map.clone().map(|rm| {
            if matches!(&*rm, RevMapping::Global(..)) {
                Arc::new(RevMapping::build_local(rm.get_categories().clone()))
            } else {
                rm
            }
        });
        DataType::Enum(enum_rev_map, *ordering)
    })
}

/// Map to a physical type.
pub fn to_physical_type(&self) -> PolarsResult<Field> {
self.map_dtype(|dtype| dtype.to_physical())
Expand Down
22 changes: 22 additions & 0 deletions py-polars/polars/expr/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,25 @@ def get_categories(self) -> Expr:
└──────┘
"""
return wrap_expr(self._pyexpr.cat_get_categories())

def to_enum(self) -> Expr:
    """
    Convert a categorical column to an Enum.

    The Enum column retains all categories of the Categorical column, regardless of
    whether those categories are represented in the column.

    Examples
    --------
    >>> df = pl.Series(
    ...     "cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical
    ... ).to_frame()
    >>> out = df.select(pl.col("cats").cat.to_enum())
    >>> out
    shape: (5, 1)
    ┌──────┐
    │ cats │
    │ ---  │
    │ enum │
    ╞══════╡
    │ foo  │
    │ bar  │
    │ foo  │
    │ foo  │
    │ ham  │
    └──────┘
    >>> out.schema["cats"]
    Enum(categories=['foo', 'bar', 'ham'])
    """
    return wrap_expr(self._pyexpr.cat_to_enum())
27 changes: 24 additions & 3 deletions py-polars/polars/series/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_categories(self) -> Series:

def is_local(self) -> bool:
"""
Return whether or not the column is a local categorical.
Return whether or not the series is a local categorical.
Examples
--------
Expand All @@ -87,9 +87,9 @@ def is_local(self) -> bool:

def to_local(self) -> Series:
"""
Convert a categorical column to its local representation.
Convert a categorical series to its local representation.
This may change the underlying physical representation of the column.
This may change the underlying physical representation of the series.
See the documentation of :func:`StringCache` for more information on the
difference between local and global categoricals.
Expand Down Expand Up @@ -120,6 +120,27 @@ def to_local(self) -> Series:
"""
return wrap_s(self._s.cat_to_local())

def to_enum(self) -> Series:
    """
    Convert a categorical Series to an Enum series.

    The Enum series retains all categories of the Categorical series, regardless of
    whether those categories are represented in the series.

    Examples
    --------
    Slicing keeps the category "c" even though it no longer occurs in the values.

    >>> s = pl.Series(["a", "b", "c"], dtype=pl.Categorical)[:2]
    >>> s_enum = s.cat.to_enum()
    >>> s_enum
    shape: (2,)
    Series: '' [enum]
    [
    	"a"
    	"b"
    ]
    >>> s_enum.dtype
    Enum(categories=['a', 'b', 'c'])
    """
    return wrap_s(self._s.cat_to_enum())

@unstable()
def uses_lexical_ordering(self) -> bool:
"""
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/expr/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ impl PyExpr {
// Python binding: clone the inner expression and apply `cat().get_categories()`.
fn cat_get_categories(&self) -> Self {
self.inner.clone().cat().get_categories().into()
}

// Python binding: clone the inner expression and apply `cat().to_enum()`.
fn cat_to_enum(&self) -> Self {
self.inner.clone().cat().to_enum().into()
}
}
5 changes: 5 additions & 0 deletions py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ impl PySeries {
Ok(ca.to_local().into_series().into())
}

// Python binding: convert the wrapped categorical series to an Enum series.
// An error from `categorical()` is converted through `PyPolarsErr`.
pub fn cat_to_enum(&self) -> PyResult<Self> {
let ca = self.series.categorical().map_err(PyPolarsErr::from)?;
Ok(ca.convert_to_enum().into_series().into())
}

// Delegates to `estimated_size` of the wrapped series.
fn estimated_size(&self) -> usize {
self.series.estimated_size()
}
Expand Down
30 changes: 30 additions & 0 deletions py-polars/tests/unit/namespaces/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import polars as pl
from polars import StringCache
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -142,3 +143,32 @@ def test_cat_uses_lexical_ordering() -> None:

s = s.cast(pl.Categorical("physical"))
assert s.cat.uses_lexical_ordering() is False


def test_cat_local_to_enum() -> None:
    """`cat.to_enum` on a local categorical keeps categories absent from the values."""
    full = pl.Series("s", ["a", "b", None, "c"], dtype=pl.Categorical)
    truncated = full[:3]  # "c" remains a category but is no longer a value
    expected_dtype = pl.Enum(["a", "b", "c"])
    expected_values = ["a", "b", None]

    # Series namespace
    as_enum = truncated.cat.to_enum()
    assert as_enum.dtype == expected_dtype
    assert as_enum.to_list() == expected_values

    # Expression namespace
    frame = pl.DataFrame(truncated).select(pl.col("s").cat.to_enum())
    assert frame["s"].dtype == expected_dtype
    assert frame["s"].to_list() == expected_values


@StringCache()
def test_cat_global_to_enum() -> None:
    """`cat.to_enum` on a global categorical yields only the series' own categories."""
    # Pre-set the global cache index. The Categorical dtype is required here:
    # a plain string series would not register anything in the string cache.
    _ = pl.Series(["d", "a", "c", "b"], dtype=pl.Categorical)

    s = pl.Series("s", ["a", "b", None, "c"], dtype=pl.Categorical)
    s = s[:3]  # extra categories not found in Series
    out_s = s.cat.to_enum()
    assert out_s.dtype == pl.Enum(["a", "b", "c"])
    assert out_s.to_list() == ["a", "b", None]

    df = pl.DataFrame(s)
    out_df = df.select(pl.col("s").cat.to_enum())
    assert out_df["s"].dtype == pl.Enum(["a", "b", "c"])
    assert out_df["s"].to_list() == ["a", "b", None]
27 changes: 27 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import polars as pl
from polars import StringCache
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -632,3 +633,29 @@ def test_literal_subtract_schema_13284() -> None:
.group_by("a")
.len()
).schema == OrderedDict([("a", pl.UInt8), ("len", pl.UInt32)])


def test_local_cat_to_enum() -> None:
    """The lazy schema of `cat.to_enum` on a local categorical is an Enum."""
    cat_series = pl.Series(["a", "b"], dtype=pl.Categorical(ordering="physical"))
    lf = pl.LazyFrame({"cat": cat_series})

    resolved = lf.select(pl.col("cat").cat.to_enum()).schema
    expected = OrderedDict({"cat": pl.Enum(categories=["a", "b"])})
    assert resolved == expected


@StringCache()
def test_global_cat_to_enum() -> None:
    """The lazy schema of `cat.to_enum` on a global categorical is an Enum."""
    # Seed the global string cache with an unrelated category order first.
    _ = pl.Series(["d", "a", "c", "b"], dtype=pl.Categorical)

    cat_series = pl.Series(["a", "b", None], dtype=pl.Categorical)
    lf = pl.LazyFrame({"cat": cat_series})

    resolved = lf.select(pl.col("cat").cat.to_enum()).schema
    expected = OrderedDict({"cat": pl.Enum(categories=["a", "b"])})
    assert resolved == expected

0 comments on commit bef1dd5

Please sign in to comment.