Skip to content

Commit

Permalink
fix: Fix enum scalar output (pola-rs#19301)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 18, 2024
1 parent 8c90286 commit 3dc421a
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 9 deletions.
6 changes: 1 addition & 5 deletions crates/polars-core/src/chunked_array/logical/enum_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,12 @@ impl EnumChunkedBuilder {
let length = arr.len() as IdxSize;
let ca = unsafe {
UInt32Chunked::new_with_dims(
Arc::new(Field::new(
self.name,
DataType::Enum(Some(self.rev.clone()), self.ordering),
)),
Arc::new(Field::new(self.name, DataType::UInt32)),
vec![Box::new(arr)],
length,
null_count,
)
};

// SAFETY: keys and values are in bounds
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering)
Expand Down
51 changes: 47 additions & 4 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use num_traits::{Float, One, ToPrimitive, Zero};
use polars_compute::float_sum;
use polars_compute::min_max::MinMaxKernel;
use polars_utils::min_max::MinMax;
use polars_utils::sync::SyncPtr;
pub use quantile::*;
pub use var::*;

Expand Down Expand Up @@ -541,12 +542,54 @@ impl CategoricalChunked {
#[cfg(feature = "dtype-categorical")]
impl ChunkAggSeries for CategoricalChunked {
fn min_reduce(&self) -> Scalar {
let av: AnyValue = self.min_categorical().into();
Scalar::new(DataType::String, av.into_static())
match self.dtype() {
DataType::Enum(r, _) => match self.physical().min() {
None => Scalar::new(self.dtype().clone(), AnyValue::Null),
Some(v) => {
let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else {
unreachable!()
};
Scalar::new(
self.dtype().clone(),
AnyValue::EnumOwned(
v,
r.as_ref().unwrap().clone(),
SyncPtr::from_const(arr as *const _),
),
)
},
},
DataType::Categorical(_, _) => {
let av: AnyValue = self.min_categorical().into();
Scalar::new(DataType::String, av.into_static())
},
_ => unreachable!(),
}
}
fn max_reduce(&self) -> Scalar {
let av: AnyValue = self.max_categorical().into();
Scalar::new(DataType::String, av.into_static())
match self.dtype() {
DataType::Enum(r, _) => match self.physical().max() {
None => Scalar::new(self.dtype().clone(), AnyValue::Null),
Some(v) => {
let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else {
unreachable!()
};
Scalar::new(
self.dtype().clone(),
AnyValue::EnumOwned(
v,
r.as_ref().unwrap().clone(),
SyncPtr::from_const(arr as *const _),
),
)
},
},
DataType::Categorical(_, _) => {
let av: AnyValue = self.max_categorical().into();
Scalar::new(DataType::String, av.into_static())
},
_ => unreachable!(),
}
}
}

Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,21 @@ def test_integer_cast_to_enum_15738(dt: pl.DataType) -> None:
assert s.to_list() == ["a", "b", "c"]
expected_s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c"]))
assert_series_equal(s, expected_s)


def test_enum_19269() -> None:
en = pl.Enum(["X", "Z", "Y"])
df = pl.DataFrame(
{"test": pl.Series(["X", "Y", "Z"], dtype=en), "group": [1, 2, 2]}
)
out = (
df.group_by("group", maintain_order=True)
.agg(pl.col("test").mode())
.select(
a=pl.col("test").list.max(),
b=pl.col("test").list.min(),
)
)

assert out.to_dict(as_series=False) == {"a": ["X", "Y"], "b": ["X", "Z"]}
assert out.dtypes == [en, en]

0 comments on commit 3dc421a

Please sign in to comment.