Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix list.mean and list.median returning Float64 for temporal types #21144

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions crates/polars-ops/src/chunked_array/list/dispersion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,22 @@ pub(super) fn median_with_nulls(ca: &ListChunked) -> Series {
.with_name(ca.name().clone());
out.into_series()
},
#[cfg(feature = "dtype-duration")]
DataType::Duration(tu) => {
#[cfg(feature = "dtype-datetime")]
DataType::Date => {
const MS_IN_DAY: i64 = 86_400_000;
let out: Int64Chunked = ca
.apply_amortized_generic(|s| {
s.and_then(|s| s.as_ref().median().map(|v| (v * (MS_IN_DAY as f64)) as i64))
})
.with_name(ca.name().clone());
out.into_datetime(TimeUnit::Milliseconds, None)
.into_series()
},
dt if dt.is_temporal() => {
let out: Int64Chunked = ca
.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64)))
.with_name(ca.name().clone());
out.into_duration(*tu).into_series()
out.cast(dt).unwrap()
},
_ => {
let out: Float64Chunked = ca
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-ops/src/chunked_array/list/sum_mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,23 @@ pub(super) fn mean_with_nulls(ca: &ListChunked) -> Series {
.with_name(ca.name().clone());
out.into_series()
},
#[cfg(feature = "dtype-datetime")]
DataType::Date => {
const MS_IN_DAY: i64 = 86_400_000;
let out: Int64Chunked = ca
.apply_amortized_generic(|s| {
s.and_then(|s| s.as_ref().mean().map(|v| (v * (MS_IN_DAY as f64)) as i64))
})
.with_name(ca.name().clone());
out.into_datetime(TimeUnit::Milliseconds, None)
.into_series()
},
dt if dt.is_temporal() => {
let out: Int64Chunked = ca
.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as i64)))
.with_name(ca.name().clone());
out.cast(dt).unwrap()
},
_ => {
let out: Float64Chunked = ca
.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean()))
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ impl ListFunction {
Sum => mapper.nested_sum_type(),
Min => mapper.map_to_list_and_array_inner_dtype(),
Max => mapper.map_to_list_and_array_inner_dtype(),
Mean => mapper.with_dtype(DataType::Float64),
Median => mapper.map_to_float_dtype(),
Mean => mapper.nested_mean_median_type(),
Median => mapper.nested_mean_median_type(),
Std(_) => mapper.map_to_float_dtype(), // Need to also have this sometimes marked as float32 or duration..
Var(_) => mapper.map_to_float_dtype(),
ArgMin => mapper.with_dtype(IDX_DTYPE),
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,26 @@ impl<'a> FieldsMapper<'a> {
Ok(first)
}

/// Output field for a `mean`/`median` aggregation over a nested column.
///
/// The result dtype is derived from the inner dtype of the first input field:
/// * `Date` becomes `Datetime(TimeUnit::Milliseconds, None)` (only when the
///   `dtype-datetime` feature is enabled),
/// * any other dtype for which `is_temporal()` is true is kept unchanged,
/// * `Float32` stays `Float32`,
/// * everything else becomes `Float64`.
///
/// When the first field has no inner dtype, a default `Unknown` dtype is fed
/// through the same mapping.
pub fn nested_mean_median_type(&self) -> PolarsResult<Field> {
    use DataType::*;

    let mut field = self.fields[0].clone();

    // Take an owned copy of the inner dtype so `field` can be mutated below.
    let inner = field
        .dtype()
        .inner_dtype()
        .map_or_else(|| Unknown(Default::default()), |inner| inner.clone());

    let output = match inner {
        #[cfg(feature = "dtype-datetime")]
        Date => Datetime(TimeUnit::Milliseconds, None),
        temporal if temporal.is_temporal() => temporal,
        Float32 => Float32,
        _ => Float64,
    };

    field.coerce(output);
    Ok(field)
}

pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
let base_dtype = self.fields[0].dtype();
let exponent_dtype = self.fields[1].dtype();
Expand Down
12 changes: 11 additions & 1 deletion py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import NUMERIC_DTYPES
from tests.unit.conftest import NUMERIC_DTYPES, TEMPORAL_DTYPES

if TYPE_CHECKING:
from polars._typing import PolarsDataType
Expand Down Expand Up @@ -864,3 +864,13 @@ def tc(a: list[Any], b: list[Any]) -> None:
tc([[1], []], [[], [1]])
tc([[2, 1]], [[2, 1]])
tc([[2, 1], [1, 2]], [[1, 2], [2, 1]])


@pytest.mark.parametrize("inner_dtype", TEMPORAL_DTYPES)
@pytest.mark.parametrize("agg", ["min", "max", "mean", "median"])
def test_list_agg_temporal(inner_dtype: PolarsDataType, agg: str) -> None:
    """List aggregations on temporal inner dtypes must match the exploded column.

    For each temporal inner dtype and each aggregation name, the `.list.<agg>()`
    result is compared against `<agg>()` applied to the exploded column, first
    on the lazily resolved schema and then on the collected frames.
    """
    frame = pl.LazyFrame({"a": [[1, 3]]}, schema={"a": pl.List(inner_dtype)})

    # Build both expressions by name so one test body covers every aggregation.
    list_expr = getattr(pl.col("a").list, agg)()
    exploded_expr = getattr(pl.col("a").explode(), agg)()

    via_list = frame.select(list_expr)
    via_explode = frame.select(exploded_expr)

    assert via_list.collect_schema() == via_explode.collect_schema()
    assert_frame_equal(via_list.collect(), via_explode.collect())
Loading