diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 959c1f5ec666..ebdd36036045 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -130,6 +130,20 @@ impl CategoricalChunked { } } + // Convert to fixed enum using existing categories. + pub fn convert_to_enum(&self) -> Self { + let s = self.to_local(); + // SAFETY: we create the physical directly from self + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + s.physical().clone(), + s.get_rev_map().clone(), + true, + s.get_ordering(), + ) + } + } + // Convert to fixed enum. Values not in categories are mapped to None. pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self { // Fast paths diff --git a/crates/polars-plan/src/dsl/cat.rs b/crates/polars-plan/src/dsl/cat.rs index 8ffe200fc7e0..f9201c7dbdf1 100644 --- a/crates/polars-plan/src/dsl/cat.rs +++ b/crates/polars-plan/src/dsl/cat.rs @@ -8,4 +8,8 @@ impl CategoricalNameSpace { self.0 .map_private(CategoricalFunction::GetCategories.into()) } + + pub fn to_enum(self) -> Expr { + self.0.map_private(CategoricalFunction::ToEnum.into()) + } } diff --git a/crates/polars-plan/src/dsl/function_expr/cat.rs b/crates/polars-plan/src/dsl/function_expr/cat.rs index db50f4ef4429..f75878a3e307 100644 --- a/crates/polars-plan/src/dsl/function_expr/cat.rs +++ b/crates/polars-plan/src/dsl/function_expr/cat.rs @@ -5,6 +5,7 @@ use crate::map; #[derive(Clone, PartialEq, Debug, Eq, Hash)] pub enum CategoricalFunction { GetCategories, + ToEnum, } impl CategoricalFunction { @@ -12,6 +13,7 @@ impl CategoricalFunction { use CategoricalFunction::*; match self { GetCategories => mapper.with_dtype(DataType::String), + ToEnum => mapper.map_cat_to_enum_dtype(), } } } @@ -21,6 +23,7 @@ impl Display for CategoricalFunction { use CategoricalFunction::*; let s = match self { GetCategories => "get_categories", + ToEnum => "to_enum", }; write!(f, "cat.{s}") } @@ -31,6 +34,7 @@ impl From for SpecialEq> { use CategoricalFunction::*; match func { GetCategories => map!(get_categories), + ToEnum => map!(to_enum), } } } @@ -48,3 +52,8 @@ fn get_categories(s: &Series) -> PolarsResult { let arr = rev_map.get_categories().clone().boxed(); Series::try_from((ca.name(), arr)) } + +fn to_enum(s: &Series) -> PolarsResult { + let ca = s.categorical()?.to_local(); + Ok(ca.convert_to_enum().into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 6a711353c52a..39c38e230ea5 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -356,6 +356,22 @@ impl<'a> FieldsMapper<'a> { }) } + // Map a categorical to an Enum type + pub fn map_cat_to_enum_dtype(&self) -> PolarsResult { + self.map_dtype(|dtype| match dtype { + DataType::Categorical(rev_map, ordering) => DataType::Enum( + rev_map.clone().map(|rm| match &*rm { + RevMapping::Global(..) => { + Arc::new(RevMapping::build_local(rm.get_categories().clone())) + }, + _ => rm, + }), + *ordering, + ), + _ => dtype.clone(), + }) + } + /// Map to a physical type. pub fn to_physical_type(&self) -> PolarsResult { self.map_dtype(|dtype| dtype.to_physical()) diff --git a/py-polars/polars/expr/categorical.py b/py-polars/polars/expr/categorical.py index ca00114c4e36..4d6b5f920aff 100644 --- a/py-polars/polars/expr/categorical.py +++ b/py-polars/polars/expr/categorical.py @@ -61,3 +61,25 @@ def get_categories(self) -> Expr: └──────┘ """ return wrap_expr(self._pyexpr.cat_get_categories()) + + def to_enum(self) -> Expr: + """ + Convert a categorical column to an Enum. + + The Enum column retains all categories of the Categorical column, regardless of + whether those categories are represented in the column. + + >>> df = pl.Series( + ... "cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical + ... ).to_frame() + >>> df.select(pl.col("cats").cat.to_enum()) + shape: (2,) + Series: '' [enum] + [ + "a" + "b" + ] + >>> s.dtype + Enum(categories=['a', 'b', 'c']) + """ + return wrap_expr(self._pyexpr.cat_to_enum()) diff --git a/py-polars/polars/series/categorical.py b/py-polars/polars/series/categorical.py index 204a125853a9..af924a4642fe 100644 --- a/py-polars/polars/series/categorical.py +++ b/py-polars/polars/series/categorical.py @@ -66,7 +66,7 @@ def get_categories(self) -> Series: def is_local(self) -> bool: """ - Return whether or not the column is a local categorical. + Return whether or not the series is a local categorical. Examples -------- @@ -87,9 +87,9 @@ def is_local(self) -> bool: def to_local(self) -> Series: """ - Convert a categorical column to its local representation. + Convert a categorical series to its local representation. - This may change the underlying physical representation of the column. + This may change the underlying physical representation of the series. See the documentation of :func:`StringCache` for more information on the difference between local and global categoricals. @@ -120,6 +120,27 @@ def to_local(self) -> Series: """ return wrap_s(self._s.cat_to_local()) + def to_enum(self) -> Series: + """ + Convert a categorical Series to an Enum series. + + The Enum series retains all categories of the Categorical series, regardless of + whether those categories are represented in the series. + + >>> s = pl.Series(["a", "b"], dtype=pl.Categorical(categories=["a", "b", "c"])) + >>> s_enum s.to_enum() + >>> s + shape: (2,) + Series: '' [enum] + [ + "a" + "b" + ] + >>> s.dtype + Enum(categories=['a', 'b', 'c']) + """ + return wrap_s(self._s.cat_to_enum()) + @unstable() def uses_lexical_ordering(self) -> bool: """ diff --git a/py-polars/src/expr/categorical.rs b/py-polars/src/expr/categorical.rs index c5b3971017ca..a4fa80457e13 100644 --- a/py-polars/src/expr/categorical.rs +++ b/py-polars/src/expr/categorical.rs @@ -16,4 +16,8 @@ impl PyExpr { fn cat_get_categories(&self) -> Self { self.inner.clone().cat().get_categories().into() } + + fn cat_to_enum(&self) -> Self { + self.inner.clone().cat().to_enum().into() + } } diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index a523ee6439ea..833906fd051f 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -111,6 +111,11 @@ impl PySeries { Ok(ca.to_local().into_series().into()) } + pub fn cat_to_enum(&self) -> PyResult { + let ca = self.series.categorical().map_err(PyPolarsErr::from)?; + Ok(ca.convert_to_enum().into_series().into()) + } + fn estimated_size(&self) -> usize { self.series.estimated_size() } diff --git a/py-polars/tests/unit/namespaces/test_categorical.py b/py-polars/tests/unit/namespaces/test_categorical.py index 03ca6497eb0d..41cdbf50a141 100644 --- a/py-polars/tests/unit/namespaces/test_categorical.py +++ b/py-polars/tests/unit/namespaces/test_categorical.py @@ -1,4 +1,5 @@ import polars as pl +from polars import StringCache from polars.testing import assert_frame_equal @@ -142,3 +143,32 @@ def test_cat_uses_lexical_ordering() -> None: s = s.cast(pl.Categorical("physical")) assert s.cat.uses_lexical_ordering() is False + + +def test_cat_local_to_enum() -> None: + s = pl.Series("s", ["a", "b", None, "c"], dtype=pl.Categorical) + s = s[:3] # extra categories not found in Series + out_s = s.cat.to_enum() + assert out_s.dtype == pl.Enum(["a", "b", "c"]) + assert out_s.to_list() == ["a", "b", None] + + out_df = pl.DataFrame(s).select(pl.col("s").cat.to_enum()) + assert out_df["s"].dtype == pl.Enum(["a", "b", "c"]) + assert out_df["s"].to_list() == ["a", "b", None] + + +@StringCache() +def test_cat_global_to_enum() -> None: + # pre-set global cache index + _ = pl.Series(["d", "a", "c", "b"]) + + s = pl.Series("s", ["a", "b", None, "c"], dtype=pl.Categorical) + s = s[:3] # extra categories not found in Series + out_s = s.cat.to_enum() + assert out_s.dtype == pl.Enum(["a", "b", "c"]) + assert out_s.to_list() == ["a", "b", None] + + df = pl.DataFrame(s) + out_df = df.select(pl.col("s").cat.to_enum()) + assert out_df["s"].dtype == pl.Enum(["a", "b", "c"]) + assert out_df["s"].to_list() == ["a", "b", None] diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index d3ee82446efe..86725ae89a2d 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -7,6 +7,7 @@ import pytest import polars as pl +from polars import StringCache from polars.testing import assert_frame_equal, assert_series_equal @@ -632,3 +633,29 @@ def test_literal_subtract_schema_13284() -> None: .group_by("a") .len() ).schema == OrderedDict([("a", pl.UInt8), ("len", pl.UInt32)]) + + +def test_local_cat_to_enum() -> None: + lf = pl.LazyFrame( + { + "cat": pl.Series(["a", "b"], dtype=pl.Categorical(ordering="physical")), + } + ) + + schema = lf.select(pl.col("cat").cat.to_enum()).schema + assert schema == OrderedDict({"cat": pl.Enum(categories=["a", "b"])}) + + +@StringCache() +def test_global_cat_to_enum() -> None: + # pre-set global cache index + _ = pl.Series(["d", "a", "c", "b"], dtype=pl.Categorical) + + lf = pl.LazyFrame( + { + "cat": pl.Series(["a", "b", None], dtype=pl.Categorical), + } + ) + + schema = lf.select(pl.col("cat").cat.to_enum()).schema + assert schema == OrderedDict({"cat": pl.Enum(categories=["a", "b"])})