Skip to content

Commit

Permalink
[FEAT] Add comparison of timestamps with same timezone (#2604)
Browse files Browse the repository at this point in the history
Specifically, this means that columns with dtype e.g.
`Timestamp(Milliseconds, Some("UTC"))` can be compared. However, this is
a naïve comparison, as e.g. "UTC" and "+00:00" still cannot be compared.
  • Loading branch information
Vince7778 authored Aug 1, 2024
1 parent 73138c9 commit c3e822c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/arrow2/src/array/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
| (Duration(Millisecond), Duration(Millisecond))
| (Duration(Microsecond), Duration(Microsecond))
| (Duration(Nanosecond), Duration(Nanosecond)) => compare_primitives::<i64>(left, right),
(Timestamp(Second, x), Timestamp(Second, y))
| (Timestamp(Millisecond, x), Timestamp(Millisecond, y))
| (Timestamp(Microsecond, x), Timestamp(Microsecond, y))
| (Timestamp(Nanosecond, x), Timestamp(Nanosecond, y))
if x == y =>
{
compare_primitives::<i64>(left, right)
}
(Float32, Float32) => compare_f32(left, right),
(Float64, Float64) => compare_f64(left, right),
(Decimal(_, _), Decimal(_, _)) => compare_primitives::<i128>(left, right),
Expand Down
18 changes: 18 additions & 0 deletions tests/dataframe/test_temporals.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import itertools
import tempfile
from datetime import datetime, timedelta, timezone

import pyarrow as pa
import pytest
import pytz

import daft
from daft import DataType, col

PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0)

Expand Down Expand Up @@ -265,3 +268,18 @@ def test_temporal_arithmetic_mismatch_granularity(t_timeunit, d_timeunit, timezo
]:
with pytest.raises(ValueError):
df.select(expression).collect()


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
@pytest.mark.parametrize("tz_repr", ["UTC", "+00:00"])
def test_join_timestamp_same_timezone(tu1, tu2, tz_repr):
tz1 = [datetime(2022, 1, 1, tzinfo=pytz.utc), datetime(2022, 2, 1, tzinfo=pytz.utc)]
tz2 = [datetime(2022, 1, 2, tzinfo=pytz.utc), datetime(2022, 1, 1, tzinfo=pytz.utc)]
df1 = daft.from_pydict({"t": tz1, "x": [1, 2]}).with_column("t", col("t").cast(DataType.timestamp(tu1, tz_repr)))
df2 = daft.from_pydict({"t": tz2, "y": [3, 4]}).with_column("t", col("t").cast(DataType.timestamp(tu2, tz_repr)))
res = df1.join(df2, on="t")
assert res.to_pydict() == {
"t": [datetime(2022, 1, 1, tzinfo=pytz.utc)],
"x": [1],
"y": [4],
}
7 changes: 7 additions & 0 deletions tests/series/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,13 @@ def test_compare_timestamps_diff_tz(tu1, tu2):
assert (tz1 == tz2).to_pylist() == [True]


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
def test_compare_lt_timestamps_same_tz(tu1, tu2):
tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu1, "UTC"))
tz2 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu2, "UTC"))
assert (tz1 < tz2).to_pylist() == [False]


@pytest.mark.parametrize("op", [operator.eq, operator.ne, operator.lt, operator.gt, operator.le, operator.ge])
def test_numeric_and_string_compare_raises_error(op):
left = Series.from_pylist([1, 2, 3])
Expand Down

0 comments on commit c3e822c

Please sign in to comment.