From c3e822cb6832c8b3da20b030bf8e19fb59dae9d9 Mon Sep 17 00:00:00 2001 From: Conor Kennedy <32619800+Vince7778@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:06:58 -0700 Subject: [PATCH] [FEAT] Add comparison of timestamps with same timezone (#2604) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Specifically, this means that columns with dtype e.g. `Timestamp(Milliseconds, Some("UTC"))` can be compared. However, this is a naïve comparison, as e.g. "UTC" and "+00:00" still cannot be compared. --- src/arrow2/src/array/ord.rs | 8 ++++++++ tests/dataframe/test_temporals.py | 18 ++++++++++++++++++ tests/series/test_comparisons.py | 7 +++++++ 3 files changed, 33 insertions(+) diff --git a/src/arrow2/src/array/ord.rs b/src/arrow2/src/array/ord.rs index 8b1d8318c1..6bf0d95126 100644 --- a/src/arrow2/src/array/ord.rs +++ b/src/arrow2/src/array/ord.rs @@ -211,6 +211,14 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result compare_primitives::(left, right), + (Timestamp(Second, x), Timestamp(Second, y)) + | (Timestamp(Millisecond, x), Timestamp(Millisecond, y)) + | (Timestamp(Microsecond, x), Timestamp(Microsecond, y)) + | (Timestamp(Nanosecond, x), Timestamp(Nanosecond, y)) + if x == y => + { + compare_primitives::(left, right) + } (Float32, Float32) => compare_f32(left, right), (Float64, Float64) => compare_f64(left, right), (Decimal(_, _), Decimal(_, _)) => compare_primitives::(left, right), diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index bf26622f1a..aff8f540b5 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -1,12 +1,15 @@ from __future__ import annotations +import itertools import tempfile from datetime import datetime, timedelta, timezone import pyarrow as pa import pytest +import pytz import daft +from daft import DataType, col PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) @@ -265,3 +268,18 @@ def test_temporal_arithmetic_mismatch_granularity(t_timeunit, d_timeunit, timezo ]: with pytest.raises(ValueError): df.select(expression).collect() + + +@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2)) +@pytest.mark.parametrize("tz_repr", ["UTC", "+00:00"]) +def test_join_timestamp_same_timezone(tu1, tu2, tz_repr): + tz1 = [datetime(2022, 1, 1, tzinfo=pytz.utc), datetime(2022, 2, 1, tzinfo=pytz.utc)] + tz2 = [datetime(2022, 1, 2, tzinfo=pytz.utc), datetime(2022, 1, 1, tzinfo=pytz.utc)] + df1 = daft.from_pydict({"t": tz1, "x": [1, 2]}).with_column("t", col("t").cast(DataType.timestamp(tu1, tz_repr))) + df2 = daft.from_pydict({"t": tz2, "y": [3, 4]}).with_column("t", col("t").cast(DataType.timestamp(tu2, tz_repr))) + res = df1.join(df2, on="t") + assert res.to_pydict() == { + "t": [datetime(2022, 1, 1, tzinfo=pytz.utc)], + "x": [1], + "y": [4], + } diff --git a/tests/series/test_comparisons.py b/tests/series/test_comparisons.py index 21591db608..bfbfd2b236 100644 --- a/tests/series/test_comparisons.py +++ b/tests/series/test_comparisons.py @@ -801,6 +801,13 @@ def test_compare_timestamps_diff_tz(tu1, tu2): assert (tz1 == tz2).to_pylist() == [True] +@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2)) +def test_compare_lt_timestamps_same_tz(tu1, tu2): + tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu1, "UTC")) + tz2 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu2, "UTC")) + assert (tz1 < tz2).to_pylist() == [False] + + @pytest.mark.parametrize("op", [operator.eq, operator.ne, operator.lt, operator.gt, operator.le, operator.ge]) def test_numeric_and_string_compare_raises_error(op): left = Series.from_pylist([1, 2, 3])