diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index c8ae80a75e..4eaea179df 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -526,6 +526,20 @@ def handle_unpack_errors(id: ShuffleId) -> Iterator[None]:
         raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e
 
 
+def _handle_datetime(buf: Any) -> Any:
+    """Return ``buf`` in a form that ``memoryview`` can wrap.
+
+    Objects whose ``dtype.kind`` is ``"M"`` or ``"m"`` (NumPy datetime64 /
+    timedelta64) are returned as a ``"u8"`` view, because NumPy does not
+    export those dtypes through the buffer protocol. Both dtypes are 8 bytes
+    wide, so the view keeps ``nbytes`` unchanged. Anything else is returned
+    as-is.
+    """
+    if hasattr(buf, "dtype") and buf.dtype.kind in "Mm":
+        return buf.view("u8")
+    return buf
+
+
 def _mean_shard_size(shards: Iterable) -> int:
     """Return estimated mean size in bytes of each shard"""
     size = 0
@@ -534,6 +548,7 @@ def _mean_shard_size(shards: Iterable) -> int:
         if not isinstance(shard, int):
             # This also asserts that shard is a Buffer and that we didn't forget
             # a container or metadata type above
+            shard = _handle_datetime(shard)
             size += memoryview(shard).nbytes
             count += 1
             if count == 10:
diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py
index bf55b45457..f2cd8564cc 100644
--- a/distributed/shuffle/tests/test_rechunk.py
+++ b/distributed/shuffle/tests/test_rechunk.py
@@ -1524,3 +1524,16 @@ def transition(self, key, start, finish, *args, stimulus_id, **kwargs):
     min_count = min(counts.values())
     max_count = max(counts.values())
     assert min_count >= max_count, counts
+
+
+@pytest.mark.parametrize("method", ["tasks", "p2p"])
+@gen_cluster(client=True)
+async def test_rechunk_datetime(c, s, *ws, method):
+    """Rechunk a datetime64[ns] array and check the values are unchanged."""
+    pd = pytest.importorskip("pandas")
+
+    x = pd.date_range("2005-01-01", "2005-01-10").to_numpy(dtype="datetime64[ns]")
+    dx = da.from_array(x, chunks=10)
+    result = dx.rechunk(2, method=method)
+    result = await c.compute(result)
+    np.testing.assert_array_equal(x, result)