From a796300f4538921cf094f39caf9d7ce7ce0dce7a Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Thu, 3 Oct 2024 12:03:40 +0100 Subject: [PATCH] feat: Allow for rolling_*_by to use index count as window (#19071) --- .../src/chunkedarray/rolling_window/dispatch.rs | 11 ++++++++++- py-polars/polars/expr/expr.py | 8 ++++++++ .../tests/unit/operations/rolling/test_rolling.py | 14 ++++++++++++-- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 17b51ee2fec2..3f6ddce20f32 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -92,10 +92,19 @@ where by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, &None, ), + DataType::Int64 => ( + by.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?, + &None, + ), + DataType::Int32 | DataType::UInt64 | DataType::UInt32 => ( + by.cast(&DataType::Int64)? + .cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?, + &None, + ), dt => polars_bail!(InvalidOperation: "in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)", dt, - "date/datetime"), + "Date/Datetime/Int64/Int32/UInt64/UInt32"), }; let ca = ca.rechunk(); let by = by.rechunk(); diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 505d71cafecd..98f638b0846a 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -6210,6 +6210,7 @@ def rolling_min_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for @@ -6331,6 +6332,7 @@ def rolling_max_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for @@ -6478,6 +6480,7 @@ def rolling_mean_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for @@ -6630,6 +6633,7 @@ def rolling_sum_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for @@ -6780,6 +6784,7 @@ def rolling_std_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for @@ -6936,6 +6941,7 @@ def rolling_var_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for @@ -7091,6 +7097,7 @@ def rolling_median_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for @@ -7220,6 +7227,7 @@ def rolling_quantile_by( - 1mo (1 calendar month) - 1q (1 calendar quarter) - 1y (1 calendar year) + - 1i (1 index count) By "calendar day", we mean the corresponding time on the next day (which may not be 24 hours, due to daylight savings). Similarly for diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 4ab0c18d4873..be88328ae40b 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -229,8 +229,10 @@ def test_rolling_crossing_dst( def test_rolling_by_invalid() -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a") - msg = r"in `rolling_\*_by` operation, `by` argument of dtype `i64` is not supported" + df = pl.DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6]}, schema_overrides={"a": pl.Int16} + ).sort("a") + msg = "unsupported data type: i16 for `window_size`, expected UInt64, UInt32, Int64, Int32, Datetime, Date, Duration, or Time" with pytest.raises(InvalidOperationError, match=msg): df.select(pl.col("b").rolling_min_by("a", "2i")) df = pl.DataFrame({"a": [1, 2, 3], "b": [date(2020, 1, 1)] * 3}).sort("b") @@ -818,6 +820,14 @@ def test_rolling_by_date() -> None: assert_frame_equal(result, expected) +@pytest.mark.parametrize("dtype", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32]) +def test_rolling_by_integer(dtype: PolarsDataType) -> None: + df = pl.DataFrame({"val": [1, 2, 3]}).with_row_index() + result = df.with_columns(roll=pl.col("val").rolling_sum_by("index", "2i")) + expected = df.with_columns(roll=pl.Series([1, 3, 5])) + assert_frame_equal(result, expected) + + def test_rolling_nanoseconds_11003() -> None: df = pl.DataFrame( {