Skip to content

Commit

Permalink
fix: Return correct schema for sum_horizontal with boolean dtype (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Dec 26, 2024
1 parent aaacdbe commit 0d4f8e7
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
12 changes: 7 additions & 5 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,13 @@ impl FunctionExpr {
MaxHorizontal => mapper.map_to_supertype(),
MinHorizontal => mapper.map_to_supertype(),
SumHorizontal { .. } => {
if mapper.fields[0].dtype() == &DataType::Boolean {
mapper.with_dtype(DataType::UInt32)
} else {
mapper.map_to_supertype()
}
mapper.map_to_supertype().map(|mut f| {
match f.dtype {
// Booleans sum to UInt32.
DataType::Boolean => { f.dtype = DataType::UInt32; f},
_ => f,
}
})
},
MeanHorizontal { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma")]
Expand Down
32 changes: 32 additions & 0 deletions py-polars/tests/unit/operations/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,35 @@ def test_expected_horizontal_dtype_errors(horizontal_func: type[pl.Expr]) -> Non
pl.col("cole"),
)
)


def test_horizontal_sum_boolean_with_null() -> None:
lf = pl.LazyFrame(
{
"null": [None, None],
"bool": [True, False],
}
)

out = lf.select(
pl.sum_horizontal("null", "bool").alias("null_first"),
pl.sum_horizontal("bool", "null").alias("bool_first"),
)

expected_schema = pl.Schema(
{
"null_first": pl.UInt32,
"bool_first": pl.UInt32,
}
)

assert out.collect_schema() == expected_schema

expected_df = pl.DataFrame(
{
"null_first": pl.Series([1, 0], dtype=pl.UInt32),
"bool_first": pl.Series([1, 0], dtype=pl.UInt32),
}
)

assert_frame_equal(out.collect(), expected_df)

0 comments on commit 0d4f8e7

Please sign in to comment.