From 77b85291a16491a053d191bb79aa0a811c90d47c Mon Sep 17 00:00:00 2001 From: cojmeister <51219103+cojmeister@users.noreply.github.com> Date: Wed, 20 Mar 2024 11:16:47 +0200 Subject: [PATCH] feat(python): Infer `time_unit` in `pl.duration` when nanoseconds is specified (#14987) Co-authored-by: Stijn de Gooijer --- py-polars/polars/functions/as_datatype.py | 13 +++++++++--- .../functions/as_datatype/test_duration.py | 20 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/functions/as_datatype.py b/py-polars/polars/functions/as_datatype.py index 418800f478e4..60a1e6c78e31 100644 --- a/py-polars/polars/functions/as_datatype.py +++ b/py-polars/polars/functions/as_datatype.py @@ -182,7 +182,7 @@ def duration( milliseconds: Expr | str | int | None = None, microseconds: Expr | str | int | None = None, nanoseconds: Expr | str | int | None = None, - time_unit: TimeUnit = "us", + time_unit: TimeUnit | None = None, ) -> Expr: """ Create polars `Duration` from distinct time components. @@ -205,8 +205,10 @@ def duration( Number of microseconds. nanoseconds Number of nanoseconds. - time_unit : {'us', 'ms', 'ns'} - Time unit of the resulting expression. + time_unit : {None, 'us', 'ms', 'ns'} + Time unit of the resulting expression. If set to `None` (default), the time + unit will be inferred from the other inputs: `'ns'` if `nanoseconds` was + specified, `'us'` otherwise. Returns ------- @@ -299,6 +301,11 @@ def duration( microseconds = parse_as_expression(microseconds) if nanoseconds is not None: nanoseconds = parse_as_expression(nanoseconds) + if time_unit is None: + time_unit = "ns" + + if time_unit is None: + time_unit = "us" return wrap_expr( plr.duration( diff --git a/py-polars/tests/unit/functions/as_datatype/test_duration.py b/py-polars/tests/unit/functions/as_datatype/test_duration.py index cc50e7a5687a..6d80467dcd45 100644 --- a/py-polars/tests/unit/functions/as_datatype/test_duration.py +++ b/py-polars/tests/unit/functions/as_datatype/test_duration.py @@ -161,3 +161,23 @@ def test_duration_subseconds_us(time_unit: TimeUnit, ms: int, us: int, ns: int) milliseconds=ms, microseconds=us, nanoseconds=ns, time_unit=time_unit ) assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_time_unit_ns() -> None: + result = pl.duration(milliseconds=4, microseconds=3_000, nanoseconds=10) + expected = pl.duration( + milliseconds=4, microseconds=3_000, nanoseconds=10, time_unit="ns" + ) + assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_time_unit_us() -> None: + result = pl.duration(milliseconds=4, microseconds=3_000) + expected = pl.duration(milliseconds=4, microseconds=3_000, time_unit="us") + assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_time_unit_ms() -> None: + result = pl.duration(milliseconds=4) + expected = pl.duration(milliseconds=4, time_unit="us") + assert_frame_equal(pl.select(result), pl.select(expected))