Skip to content

Commit

Permalink
[data] Add better support for udf returns from list of datetime objec…
Browse files Browse the repository at this point in the history
…ts (#46762)

This PR improves the conversion of lists of datetime objects returned from UDFs. It ensures that, when pyarrow blocks are used, the datetime objects are loaded as timestamps.


---------

Signed-off-by: Matthew Owen <[email protected]>
  • Loading branch information
omatthew98 authored Jul 26, 2024
1 parent 8d2b459 commit 69f3218
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/ray/data/_internal/numpy_support.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
from datetime import datetime
from typing import Any, Dict, List, Union

import numpy as np
Expand Down Expand Up @@ -40,6 +41,31 @@ def validate_numpy_batch(batch: Union[Dict[str, np.ndarray], Dict[str, list]]) -
)


def _detect_highest_datetime_precision(datetime_list: List[datetime]) -> str:
highest_precision = "D"

for dt in datetime_list:
if dt.microsecond != 0 and dt.microsecond % 1000 != 0:
highest_precision = "us"
break
elif dt.microsecond != 0 and dt.microsecond % 1000 == 0:
highest_precision = "ms"
elif dt.hour != 0 or dt.minute != 0 or dt.second != 0:
# pyarrow does not support h or m, use s for those cases too
highest_precision = "s"

return highest_precision


def _convert_datetime_list_to_array(datetime_list: List[datetime]) -> np.ndarray:
    """Build a ``datetime64`` numpy array from a list of datetimes.

    The unit of the resulting array is chosen by
    ``_detect_highest_datetime_precision`` and applied to every element.

    Args:
        datetime_list: The datetimes to convert.

    Returns:
        An array of dtype ``datetime64[<unit>]`` with one entry per input
        datetime.
    """
    unit = _detect_highest_datetime_precision(datetime_list)
    target_dtype = f"datetime64[{unit}]"

    # Convert each element explicitly at the chosen unit before assembling
    # the array.
    converted = [np.datetime64(value, unit) for value in datetime_list]
    return np.array(converted, dtype=target_dtype)


def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
"""Convert UDF columns (output of map_batches) to numpy, if possible.
Expand All @@ -64,6 +90,9 @@ def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
udf_return_col = np.expand_dims(udf_return_col[0], axis=0)
return udf_return_col

if all(isinstance(elem, datetime) for elem in udf_return_col):
return _convert_datetime_list_to_array(udf_return_col)

# Try to convert list values into an numpy array via
# np.array(), so users don't need to manually cast.
# NOTE: we don't cast generic iterables, since types like
Expand Down
47 changes: 47 additions & 0 deletions python/ray/data/tests/test_numpy_support.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -46,6 +48,51 @@ def test_list_of_objects(ray_start_regular_shared):
assert_structure_equals(output, np.array([1, 2, 3, UserObj()]))


# Python datetimes on 2024-01-01, each with exactly one field set so that it
# needs a different minimum precision to be represented.
DATETIME_DAY_PRECISION = datetime(2024, 1, 1)
DATETIME_HOUR_PRECISION = DATETIME_DAY_PRECISION.replace(hour=1)
DATETIME_MIN_PRECISION = DATETIME_DAY_PRECISION.replace(minute=1)
DATETIME_SEC_PRECISION = DATETIME_DAY_PRECISION.replace(second=1)
DATETIME_MILLISEC_PRECISION = DATETIME_DAY_PRECISION.replace(microsecond=1000)
DATETIME_MICROSEC_PRECISION = DATETIME_DAY_PRECISION.replace(microsecond=1)

# Matching numpy datetime64 values with their unit spelled out. Hour and
# minute values are held at second precision.
DATETIME64_DAY_PRECISION = np.datetime64("2024-01-01", "D")
DATETIME64_HOUR_PRECISION = np.datetime64("2024-01-01T01:00", "s")
DATETIME64_MIN_PRECISION = np.datetime64("2024-01-01T00:01", "s")
DATETIME64_SEC_PRECISION = np.datetime64("2024-01-01T00:00:01", "s")
DATETIME64_MILLISEC_PRECISION = np.datetime64("2024-01-01T00:00:00.001", "ms")
DATETIME64_MICROSEC_PRECISION = np.datetime64("2024-01-01T00:00:00.000001", "us")


@pytest.mark.parametrize(
    "data,expected_output",
    [
        # Single-element lists: one case per precision level.
        ([DATETIME_DAY_PRECISION], np.array([DATETIME64_DAY_PRECISION])),
        ([DATETIME_HOUR_PRECISION], np.array([DATETIME64_HOUR_PRECISION])),
        ([DATETIME_MIN_PRECISION], np.array([DATETIME64_MIN_PRECISION])),
        ([DATETIME_SEC_PRECISION], np.array([DATETIME64_SEC_PRECISION])),
        ([DATETIME_MILLISEC_PRECISION], np.array([DATETIME64_MILLISEC_PRECISION])),
        ([DATETIME_MICROSEC_PRECISION], np.array([DATETIME64_MICROSEC_PRECISION])),
        # Mixed precisions: the whole array takes the finest unit present.
        # Expected values are built only from the DATETIME64_* constants so
        # they do not rely on numpy coercing a Python datetime.
        (
            [DATETIME_MICROSEC_PRECISION, DATETIME_MILLISEC_PRECISION],
            np.array(
                [DATETIME64_MICROSEC_PRECISION, DATETIME64_MILLISEC_PRECISION],
                dtype="datetime64[us]",
            ),
        ),
        (
            [DATETIME_SEC_PRECISION, DATETIME_MILLISEC_PRECISION],
            np.array(
                [DATETIME64_SEC_PRECISION, DATETIME64_MILLISEC_PRECISION],
                dtype="datetime64[ms]",
            ),
        ),
    ],
)
def test_list_of_datetimes(data, expected_output, ray_start_regular_shared):
    """Check that a list of datetimes passed through do_map_batches matches
    the expected datetime64 array."""
    output = do_map_batches(data)
    assert_structure_equals(output, expected_output)


def test_array_like(ray_start_regular_shared):
data = torch.Tensor([1, 2, 3])
output = do_map_batches(data)
Expand Down

0 comments on commit 69f3218

Please sign in to comment.