From 46edcd8c4a6e0321199358c83fd8f2b9c3056f8f Mon Sep 17 00:00:00 2001
From: Ritchie Vink
Date: Sat, 19 Oct 2024 12:27:27 +0200
Subject: [PATCH] fix: Merge categorical rev-map in `unpivot` (#19313)

---
 crates/polars-core/src/datatypes/dtype.rs     |  5 +-
 crates/polars-core/src/utils/supertype.rs     | 53 +++++++++++++++++++
 crates/polars-ops/src/frame/pivot/unpivot.rs  | 21 ++++----
 .../tests/unit/operations/test_unpivot.py     | 19 +++++++
 4 files changed, 85 insertions(+), 13 deletions(-)

diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs
index e18dd9026a4d..02cf360f3bf5 100644
--- a/crates/polars-core/src/datatypes/dtype.rs
+++ b/crates/polars-core/src/datatypes/dtype.rs
@@ -115,9 +115,10 @@ impl PartialEq for DataType {
         use DataType::*;
         {
             match (self, other) {
-                // Don't include rev maps in comparisons
                 #[cfg(feature = "dtype-categorical")]
-                (Categorical(_, _), Categorical(_, _)) => true,
+                // Don't include rev maps in comparisons
+                // TODO: include ordering in comparison
+                (Categorical(_, _ordering_l), Categorical(_, _ordering_r)) => true,
                 #[cfg(feature = "dtype-categorical")]
                 // None means select all Enum dtypes. This is for operation `pl.col(pl.Enum)`
                 (Enum(None, _), Enum(_, _)) | (Enum(_, _), Enum(None, _)) => true,
diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs
index 027e85886793..18b8c9ddd00a 100644
--- a/crates/polars-core/src/utils/supertype.rs
+++ b/crates/polars-core/src/utils/supertype.rs
@@ -498,3 +498,56 @@ fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> {
         },
     }
 }
+
+pub fn merge_dtypes_many<I: IntoIterator<Item = D> + Clone, D: AsRef<DataType>>(
+    into_iter: I,
+) -> PolarsResult<DataType> {
+    let mut iter = into_iter.clone().into_iter();
+
+    let mut st = iter
+        .next()
+        .ok_or_else(|| polars_err!(ComputeError: "expect at least 1 dtype"))
+        .map(|d| d.as_ref().clone())?;
+
+    for d in iter {
+        st = try_get_supertype(d.as_ref(), &st)?;
+    }
+
+    match st {
+        #[cfg(feature = "dtype-categorical")]
+        DataType::Categorical(Some(_), ordering) => {
+            // This merges the global rev maps with linear complexity.
+            // If we do a binary reduce, it would be quadratic.
+            let mut iter = into_iter.into_iter();
+            let first_dt = iter.next().unwrap();
+            let first_dt = first_dt.as_ref();
+            let DataType::Categorical(Some(rm), _) = first_dt else {
+                unreachable!()
+            };
+
+            let mut merger = GlobalRevMapMerger::new(rm.clone());
+
+            for d in iter {
+                if let DataType::Categorical(Some(rm), _) = d.as_ref() {
+                    merger.merge_map(rm)?
+                }
+            }
+            let rev_map = merger.finish();
+
+            Ok(DataType::Categorical(Some(rev_map), ordering))
+        },
+        // This would be quadratic if we do this with the binary `merge_dtypes`.
+        DataType::List(inner) if inner.contains_categoricals() => {
+            polars_bail!(ComputeError: "merging nested categoricals not yet supported")
+        },
+        #[cfg(feature = "dtype-array")]
+        DataType::Array(inner, _) if inner.contains_categoricals() => {
+            polars_bail!(ComputeError: "merging nested categoricals not yet supported")
+        },
+        #[cfg(feature = "dtype-struct")]
+        DataType::Struct(fields) if fields.iter().any(|f| f.dtype().contains_categoricals()) => {
+            polars_bail!(ComputeError: "merging nested categoricals not yet supported")
+        },
+        _ => Ok(st),
+    }
+}
diff --git a/crates/polars-ops/src/frame/pivot/unpivot.rs b/crates/polars-ops/src/frame/pivot/unpivot.rs
index 60f5bff8eae9..49eeaeba4498 100644
--- a/crates/polars-ops/src/frame/pivot/unpivot.rs
+++ b/crates/polars-ops/src/frame/pivot/unpivot.rs
@@ -4,7 +4,7 @@ use polars_core::datatypes::{DataType, PlSmallStr};
 use polars_core::frame::column::Column;
 use polars_core::frame::DataFrame;
 use polars_core::prelude::{IntoVec, Series, UnpivotArgsIR};
-use polars_core::utils::try_get_supertype;
+use polars_core::utils::merge_dtypes_many;
 use polars_error::{polars_err, PolarsResult};
 use polars_utils::aliases::PlHashSet;
 
@@ -104,9 +104,9 @@ pub trait UnpivotDF: IntoDf {
 
         let len = self_.height();
 
-        // if value vars is empty we take all columns that are not in id_vars.
+        // If value vars is empty we take all columns that are not in id_vars.
         if on.is_empty() {
-            // return empty frame if there are no columns available to use as value vars
+            // Return empty frame if there are no columns available to use as value vars.
             if index.len() == self_.width() {
                 let variable_col = Column::new_empty(variable_name, &DataType::String);
                 let value_col = Column::new_empty(value_name, &DataType::Null);
@@ -133,15 +133,14 @@ pub trait UnpivotDF: IntoDf {
                 .collect();
         }
 
-        // values will all be placed in single column, so we must find their supertype
+        // Values will all be placed in single column, so we must find their supertype
         let schema = self_.schema();
-        let mut iter = on
+        let dtypes = on
             .iter()
-            .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v)));
-        let mut st = iter.next().unwrap()?.clone();
-        for dt in iter {
-            st = try_get_supertype(&st, dt?)?;
-        }
+            .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v)))
+            .collect::<PolarsResult<Vec<_>>>()?;
+
+        let st = merge_dtypes_many(dtypes.iter())?;
 
         // The column name of the variable that is unpivoted
         let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1);
@@ -167,7 +166,7 @@ pub trait UnpivotDF: IntoDf {
             let (pos, _name, _dtype) = schema.try_get_full(value_column_name)?;
             let col = &columns[pos];
             let value_col = col.cast(&st).map_err(
-                |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}", col.dtype()),
+                |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}\n\nConsider casting to String.", col.dtype()),
             )?;
             values.extend_from_slice(value_col.as_materialized_series().chunks())
         }
diff --git a/py-polars/tests/unit/operations/test_unpivot.py b/py-polars/tests/unit/operations/test_unpivot.py
index ada642c294ae..434c2fdc3af9 100644
--- a/py-polars/tests/unit/operations/test_unpivot.py
+++ b/py-polars/tests/unit/operations/test_unpivot.py
@@ -2,6 +2,7 @@
 
 import polars as pl
 import polars.selectors as cs
+from polars import StringCache
 from polars.testing import assert_frame_equal
 
 
@@ -94,3 +95,21 @@ def test_unpivot_empty_18170() -> None:
     assert pl.DataFrame().unpivot().schema == pl.Schema(
         {"variable": pl.String(), "value": pl.Null()}
     )
+
+
+@StringCache()
+def test_unpivot_categorical_global() -> None:
+    df = pl.DataFrame(
+        {
+            "index": [0, 1],
+            "1": pl.Series(["a", "b"], dtype=pl.Categorical),
+            "2": pl.Series(["b", "c"], dtype=pl.Categorical),
+        }
+    )
+    out = df.unpivot(["1", "2"], index="index")
+    assert out.dtypes == [pl.Int64, pl.String, pl.Categorical(ordering="physical")]
+    assert out.to_dict(as_series=False) == {
+        "index": [0, 1, 0, 1],
+        "variable": ["1", "1", "2", "2"],
+        "value": ["a", "b", "b", "c"],
+    }