From 3dc421a5a802cf74b3feffa638d7aa69e97f8211 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 18 Oct 2024 13:49:05 +0200 Subject: [PATCH] fix: Fix enum scalar output (#19301) --- .../src/chunked_array/logical/enum_/mod.rs | 6 +-- .../src/chunked_array/ops/aggregate/mod.rs | 51 +++++++++++++++++-- py-polars/tests/unit/datatypes/test_enum.py | 18 +++++++ 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs index e143a59a7f7b..5279099f1cd7 100644 --- a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs @@ -85,16 +85,12 @@ impl EnumChunkedBuilder { let length = arr.len() as IdxSize; let ca = unsafe { UInt32Chunked::new_with_dims( - Arc::new(Field::new( - self.name, - DataType::Enum(Some(self.rev.clone()), self.ordering), - )), + Arc::new(Field::new(self.name, DataType::UInt32)), vec![Box::new(arr)], length, null_count, ) }; - // SAFETY: keys and values are in bounds unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering) diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 5b3c0b5f53d7..0a059eb54274 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -11,6 +11,7 @@ use num_traits::{Float, One, ToPrimitive, Zero}; use polars_compute::float_sum; use polars_compute::min_max::MinMaxKernel; use polars_utils::min_max::MinMax; +use polars_utils::sync::SyncPtr; pub use quantile::*; pub use var::*; @@ -541,12 +542,54 @@ impl CategoricalChunked { #[cfg(feature = "dtype-categorical")] impl ChunkAggSeries for CategoricalChunked { fn min_reduce(&self) -> Scalar { - let av: AnyValue = self.min_categorical().into(); - Scalar::new(DataType::String, av.into_static()) + match self.dtype() { + DataType::Enum(r, _) => match self.physical().min() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(_, _) => { + let av: AnyValue = self.min_categorical().into(); + Scalar::new(DataType::String, av.into_static()) + }, + _ => unreachable!(), + } } fn max_reduce(&self) -> Scalar { - let av: AnyValue = self.max_categorical().into(); - Scalar::new(DataType::String, av.into_static()) + match self.dtype() { + DataType::Enum(r, _) => match self.physical().max() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(_, _) => { + let av: AnyValue = self.max_categorical().into(); + Scalar::new(DataType::String, av.into_static()) + }, + _ => unreachable!(), + } } } diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 9ad384cd1bab..9bd4a49ddd1d 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -536,3 +536,21 @@ def test_integer_cast_to_enum_15738(dt: pl.DataType) -> None: assert s.to_list() == ["a", "b", "c"] expected_s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) assert_series_equal(s, expected_s) + + +def test_enum_19269() -> None: + en = pl.Enum(["X", "Z", "Y"]) + df = pl.DataFrame( + {"test": pl.Series(["X", "Y", "Z"], dtype=en), "group": [1, 2, 2]} + ) + out = ( + df.group_by("group", maintain_order=True) + .agg(pl.col("test").mode()) + .select( + a=pl.col("test").list.max(), + b=pl.col("test").list.min(), + ) + ) + + assert out.to_dict(as_series=False) == {"a": ["X", "Y"], "b": ["X", "Z"]} + assert out.dtypes == [en, en]