From 0d4f8e7e3cdee6dd203f0043615762d1a33af313 Mon Sep 17 00:00:00 2001 From: Marshall Date: Thu, 26 Dec 2024 05:05:12 -0500 Subject: [PATCH] fix: Return correct schema for `sum_horizontal` with boolean dtype (#20459) --- .../src/dsl/function_expr/schema.rs | 12 ++++--- .../operations/aggregation/test_horizontal.py | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 018a3b0207d1..106a05c2041f 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -330,11 +330,13 @@ impl FunctionExpr { MaxHorizontal => mapper.map_to_supertype(), MinHorizontal => mapper.map_to_supertype(), SumHorizontal { .. } => { - if mapper.fields[0].dtype() == &DataType::Boolean { - mapper.with_dtype(DataType::UInt32) - } else { - mapper.map_to_supertype() - } + mapper.map_to_supertype().map(|mut f| { + match f.dtype { + // Booleans sum to UInt32. + DataType::Boolean => { f.dtype = DataType::UInt32; f}, + _ => f, + } + }) }, MeanHorizontal { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "ewma")] diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index 3ee5514a4563..ea93b1da97fd 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -524,3 +524,35 @@ def test_expected_horizontal_dtype_errors(horizontal_func: type[pl.Expr]) -> Non pl.col("cole"), ) ) + + +def test_horizontal_sum_boolean_with_null() -> None: + lf = pl.LazyFrame( + { + "null": [None, None], + "bool": [True, False], + } + ) + + out = lf.select( + pl.sum_horizontal("null", "bool").alias("null_first"), + pl.sum_horizontal("bool", "null").alias("bool_first"), + ) + + expected_schema = pl.Schema( + { + "null_first": pl.UInt32, + "bool_first": pl.UInt32, + } + ) + + assert out.collect_schema() == expected_schema + + expected_df = pl.DataFrame( + { + "null_first": pl.Series([1, 0], dtype=pl.UInt32), + "bool_first": pl.Series([1, 0], dtype=pl.UInt32), + } + ) + + assert_frame_equal(out.collect(), expected_df)