From 6208079f3a6d27c90b2aac5aa64d01258e826b50 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 18 Sep 2024 14:33:30 +0200 Subject: [PATCH 1/3] fix: Proper dtype casting for struct embedded categoricals --- crates/polars-core/src/datatypes/dtype.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 9eec5dff66f5..8924198bccec 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -764,7 +764,6 @@ impl Display for DataType { } pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType> { - // TODO! add struct use DataType::*; Ok(match (left, right) { #[cfg(feature = "dtype-categorical")] @@ -794,6 +793,15 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType> let merged = merge_dtypes(inner_l, inner_r)?; List(Box::new(merged)) }, + (Struct(inner_l), Struct(inner_r)) => { + polars_ensure!(inner_l.len() == inner_r.len(), ComputeError: "cannot combine different structs"); + let fields = inner_l.iter().zip(inner_r.iter()).map(|(l, r)| { + polars_ensure!(l.name() == r.name(), ComputeError: "cannot combine different structs"); + let merged = merge_dtypes(l.dtype(), r.dtype())?; + Ok(Field::new(l.name().clone(), merged)) + }).collect::<PolarsResult<Vec<_>>>()?; + Struct(fields) + }, #[cfg(feature = "dtype-array")] (Array(inner_l, width_l), Array(inner_r, width_r)) => { polars_ensure!(width_l == width_r, ComputeError: "widths of FixedSizeWidth Series are not equal"); From d3cb162bc4a2a8479233864abad21c3d88668e00 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 18 Sep 2024 14:59:49 +0200 Subject: [PATCH 2/3] test and better error messages --- crates/polars-core/src/datatypes/dtype.rs | 5 +++-- .../tests/unit/datatypes/test_categorical.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/datatypes/dtype.rs 
b/crates/polars-core/src/datatypes/dtype.rs index 8924198bccec..956d055a52c2 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -793,10 +793,11 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType> let merged = merge_dtypes(inner_l, inner_r)?; List(Box::new(merged)) }, + #[cfg(feature = "dtype-struct")] (Struct(inner_l), Struct(inner_r)) => { - polars_ensure!(inner_l.len() == inner_r.len(), ComputeError: "cannot combine different structs"); + polars_ensure!(inner_l.len() == inner_r.len(), ComputeError: "cannot combine structs with differing amounts of fields ({} != {})", inner_l.len(), inner_r.len()); let fields = inner_l.iter().zip(inner_r.iter()).map(|(l, r)| { - polars_ensure!(l.name() == r.name(), ComputeError: "cannot combine different structs"); + polars_ensure!(l.name() == r.name(), ComputeError: "cannot combine structs with different fields ({} != {})", l.name(), r.name()); let merged = merge_dtypes(l.dtype(), r.dtype())?; Ok(Field::new(l.name().clone(), merged)) }).collect::<PolarsResult<Vec<_>>>()?; diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index b898c7c07999..209285a8fac5 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -7,6 +7,8 @@ import pytest +from hypothesis import given + import polars as pl from polars import StringCache from polars.exceptions import ( @@ -845,3 +847,19 @@ def test_get_cat_categories_multiple_chunks() -> None: ) df_cat = df.lazy().select(pl.col("e").cat.get_categories()).collect() assert len(df_cat) == 2 + + +@pytest.mark.parametrize( + "f", [ + lambda x: (pl.List(pl.Categorical), [x]), + lambda x: (pl.Struct({ 'a': pl.Categorical }), { 'a': x }), + ] +) +def test_nested_categorical_concat(f: Callable[[str], tuple[pl.DataType, list[str] | dict[str, str]]]) -> None: + dt, va = f("a") + _, vb = f("b") + a = pl.DataFrame({ 'x': [va] 
}, schema={ 'x': dt }) + b = pl.DataFrame({ 'x': [vb] }, schema={ 'x': dt }) + + with pytest.raises(pl.exceptions.StringCacheMismatchError): + stack = pl.concat([a, b]) From 74a1fb706ef860295723d120508f36f447a2bb8e Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 18 Sep 2024 15:00:50 +0200 Subject: [PATCH 3/3] pyfmt --- .../tests/unit/datatypes/test_categorical.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index 209285a8fac5..c5888abee67e 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -7,8 +7,6 @@ import pytest -from hypothesis import given - import polars as pl from polars import StringCache from polars.exceptions import ( @@ -850,16 +848,19 @@ def test_get_cat_categories_multiple_chunks() -> None: @pytest.mark.parametrize( - "f", [ + "f", + [ lambda x: (pl.List(pl.Categorical), [x]), - lambda x: (pl.Struct({ 'a': pl.Categorical }), { 'a': x }), - ] + lambda x: (pl.Struct({"a": pl.Categorical}), {"a": x}), + ], ) -def test_nested_categorical_concat(f: Callable[[str], tuple[pl.DataType, list[str] | dict[str, str]]]) -> None: +def test_nested_categorical_concat( + f: Callable[[str], tuple[pl.DataType, list[str] | dict[str, str]]], +) -> None: dt, va = f("a") _, vb = f("b") - a = pl.DataFrame({ 'x': [va] }, schema={ 'x': dt }) - b = pl.DataFrame({ 'x': [vb] }, schema={ 'x': dt }) + a = pl.DataFrame({"x": [va]}, schema={"x": dt}) + b = pl.DataFrame({"x": [vb]}, schema={"x": dt}) with pytest.raises(pl.exceptions.StringCacheMismatchError): - stack = pl.concat([a, b]) + pl.concat([a, b])