Skip to content

Commit

Permalink
feat: allow for rolling_*_by to use index count as window
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Oct 2, 2024
1 parent eb615e0 commit a0ed213
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
11 changes: 10 additions & 1 deletion crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,19 @@ where
by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
&None,
),
DataType::Int64 => (
by.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
&None,
),
DataType::Int32 | DataType::UInt64 | DataType::UInt32 => (
by.cast(&DataType::Int64)?
.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
&None,
),
dt => polars_bail!(InvalidOperation:
"in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)",
dt,
"date/datetime"),
"Date/Datetime/Int64/Int32/UInt64/UInt32"),
};
let ca = ca.rechunk();
let by = by.rechunk();
Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6210,6 +6210,7 @@ def rolling_min_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6331,6 +6332,7 @@ def rolling_max_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6478,6 +6480,7 @@ def rolling_mean_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6630,6 +6633,7 @@ def rolling_sum_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6780,6 +6784,7 @@ def rolling_std_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6936,6 +6941,7 @@ def rolling_var_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -7091,6 +7097,7 @@ def rolling_median_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -7220,6 +7227,7 @@ def rolling_quantile_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)
By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down
14 changes: 12 additions & 2 deletions py-polars/tests/unit/operations/rolling/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ def test_rolling_crossing_dst(


def test_rolling_by_invalid() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a")
msg = r"in `rolling_\*_by` operation, `by` argument of dtype `i64` is not supported"
df = pl.DataFrame(
{"a": [1, 2, 3], "b": [4, 5, 6]}, schema_overrides={"a": pl.Int16}
).sort("a")
msg = "unsupported data type: i16 for `window_size`, expected UInt64, UInt32, Int64, Int32, Datetime, Date, Duration, or Time"
with pytest.raises(InvalidOperationError, match=msg):
df.select(pl.col("b").rolling_min_by("a", "2i"))
df = pl.DataFrame({"a": [1, 2, 3], "b": [date(2020, 1, 1)] * 3}).sort("b")
Expand Down Expand Up @@ -818,6 +820,14 @@ def test_rolling_by_date() -> None:
assert_frame_equal(result, expected)


@pytest.mark.parametrize("dtype", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32])
def test_rolling_by_integer(dtype: PolarsDataType) -> None:
df = pl.DataFrame({"val": [1, 2, 3]}).with_row_index()
result = df.with_columns(roll=pl.col("val").rolling_sum_by("index", "2i"))
expected = df.with_columns(roll=pl.Series([1, 3, 5]))
assert_frame_equal(result, expected)


def test_rolling_nanoseconds_11003() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit a0ed213

Please sign in to comment.