Skip to content

Commit

Permalink
fix: Merge categorical rev-map in unpivot (#19313)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 19, 2024
1 parent f88bd6a commit 46edcd8
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 13 deletions.
5 changes: 3 additions & 2 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ impl PartialEq for DataType {
use DataType::*;
{
match (self, other) {
// Don't include rev maps in comparisons
#[cfg(feature = "dtype-categorical")]
(Categorical(_, _), Categorical(_, _)) => true,
// Don't include rev maps in comparisons
// TODO: include ordering in comparison
(Categorical(_, _ordering_l), Categorical(_, _ordering_r)) => true,
#[cfg(feature = "dtype-categorical")]
// None means select all Enum dtypes. This is for operation `pl.col(pl.Enum)`
(Enum(None, _), Enum(_, _)) | (Enum(_, _), Enum(None, _)) => true,
Expand Down
53 changes: 53 additions & 0 deletions crates/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,56 @@ fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> {
},
}
}

/// Compute the common dtype of many dtypes in one pass.
///
/// The supertype of all inputs is first determined by folding
/// `try_get_supertype` over the iterator. If that supertype is a categorical
/// carrying a rev-map, the rev-maps of all categorical inputs are merged
/// through a single `GlobalRevMapMerger` and the merged map is returned in
/// the resulting dtype.
///
/// # Errors
///
/// - The iterator yields no dtypes.
/// - `try_get_supertype` fails for any pair of dtypes.
/// - Merging a rev-map into the merger fails.
/// - The supertype is a `List`, `Array` or `Struct` that contains
///   categoricals; merging nested categoricals is not supported.
pub fn merge_dtypes_many<I: IntoIterator<Item = D> + Clone, D: AsRef<DataType>>(
    into_iter: I,
) -> PolarsResult<DataType> {
    let mut iter = into_iter.clone().into_iter();

    let mut st = iter
        .next()
        .ok_or_else(|| polars_err!(ComputeError: "expect at least 1 dtype"))
        .map(|d| d.as_ref().clone())?;

    for d in iter {
        st = try_get_supertype(d.as_ref(), &st)?;
    }

    match st {
        #[cfg(feature = "dtype-categorical")]
        DataType::Categorical(Some(_), ordering) => {
            // This merges the global rev maps with linear complexity.
            // If we do a binary reduce, it would be quadratic.
            //
            // The first input is not guaranteed to be the categorical that
            // produced the supertype (another dtype may precede it), so collect
            // the rev-maps from every categorical input instead of assuming
            // the position of the first one.
            let mut rev_maps = into_iter.into_iter().filter_map(|d| {
                if let DataType::Categorical(Some(rm), _) = d.as_ref() {
                    Some(rm.clone())
                } else {
                    None
                }
            });

            let first_rm = rev_maps.next().ok_or_else(
                || polars_err!(ComputeError: "expected at least 1 categorical dtype with a rev-map"),
            )?;

            let mut merger = GlobalRevMapMerger::new(first_rm);
            for rm in rev_maps {
                merger.merge_map(&rm)?;
            }
            let rev_map = merger.finish();

            Ok(DataType::Categorical(Some(rev_map), ordering))
        },
        // This would be quadratic if we do this with the binary `merge_dtypes`.
        DataType::List(inner) if inner.contains_categoricals() => {
            polars_bail!(ComputeError: "merging nested categoricals not yet supported")
        },
        #[cfg(feature = "dtype-array")]
        DataType::Array(inner, _) if inner.contains_categoricals() => {
            polars_bail!(ComputeError: "merging nested categoricals not yet supported")
        },
        #[cfg(feature = "dtype-struct")]
        DataType::Struct(fields) if fields.iter().any(|f| f.dtype().contains_categoricals()) => {
            polars_bail!(ComputeError: "merging nested categoricals not yet supported")
        },
        _ => Ok(st),
    }
}
21 changes: 10 additions & 11 deletions crates/polars-ops/src/frame/pivot/unpivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use polars_core::datatypes::{DataType, PlSmallStr};
use polars_core::frame::column::Column;
use polars_core::frame::DataFrame;
use polars_core::prelude::{IntoVec, Series, UnpivotArgsIR};
use polars_core::utils::try_get_supertype;
use polars_core::utils::merge_dtypes_many;
use polars_error::{polars_err, PolarsResult};
use polars_utils::aliases::PlHashSet;

Expand Down Expand Up @@ -104,9 +104,9 @@ pub trait UnpivotDF: IntoDf {

let len = self_.height();

// if value vars is empty we take all columns that are not in id_vars.
// If value vars is empty we take all columns that are not in id_vars.
if on.is_empty() {
// return empty frame if there are no columns available to use as value vars
// Return empty frame if there are no columns available to use as value vars.
if index.len() == self_.width() {
let variable_col = Column::new_empty(variable_name, &DataType::String);
let value_col = Column::new_empty(value_name, &DataType::Null);
Expand All @@ -133,15 +133,14 @@ pub trait UnpivotDF: IntoDf {
.collect();
}

// values will all be placed in single column, so we must find their supertype
// Values will all be placed in single column, so we must find their supertype
let schema = self_.schema();
let mut iter = on
let dtypes = on
.iter()
.map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v)));
let mut st = iter.next().unwrap()?.clone();
for dt in iter {
st = try_get_supertype(&st, dt?)?;
}
.map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v)))
.collect::<PolarsResult<Vec<_>>>()?;

let st = merge_dtypes_many(dtypes.iter())?;

// The column name of the variable that is unpivoted
let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1);
Expand All @@ -167,7 +166,7 @@ pub trait UnpivotDF: IntoDf {
let (pos, _name, _dtype) = schema.try_get_full(value_column_name)?;
let col = &columns[pos];
let value_col = col.cast(&st).map_err(
|_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}", col.dtype()),
|_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}\n\nConsider casting to String.", col.dtype()),
)?;
values.extend_from_slice(value_col.as_materialized_series().chunks())
}
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/operations/test_unpivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import polars as pl
import polars.selectors as cs
from polars import StringCache
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -94,3 +95,21 @@ def test_unpivot_empty_18170() -> None:
assert pl.DataFrame().unpivot().schema == pl.Schema(
{"variable": pl.String(), "value": pl.Null()}
)


@StringCache()
def test_unpivot_categorical_global() -> None:
    # Two categorical columns with overlapping but different categories are
    # unpivoted into a single value column while the string cache is active.
    frame = pl.DataFrame(
        {
            "index": [0, 1],
            "1": pl.Series(["a", "b"], dtype=pl.Categorical),
            "2": pl.Series(["b", "c"], dtype=pl.Categorical),
        }
    )
    result = frame.unpivot(["1", "2"], index="index")

    expected_dtypes = [pl.Int64, pl.String, pl.Categorical(ordering="physical")]
    expected_columns = {
        "index": [0, 1, 0, 1],
        "variable": ["1", "1", "2", "2"],
        "value": ["a", "b", "b", "c"],
    }

    # The value column must stay categorical and keep every original string.
    assert result.dtypes == expected_dtypes
    assert result.to_dict(as_series=False) == expected_columns

0 comments on commit 46edcd8

Please sign in to comment.