diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 9e35673a..0ff2208b 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -3,6 +3,7 @@ import operator import numpy as np +from dask._task_spec import Task, TaskRef from dask.dataframe.dispatch import make_meta, meta_nonempty from dask.dataframe.multi import ( _concat_wrapper, @@ -638,10 +639,10 @@ def _layer(self) -> dict: transfer_keys_right = list() func = create_assign_index_merge_transfer() for i in range(self.left.npartitions): - transfer_keys_left.append((transfer_name_left, i)) - dsk[(transfer_name_left, i)] = ( + t = Task( + (transfer_name_left, i), func, - (self.left._name, i), + TaskRef((self.left._name, i)), self.shuffle_left_on, _HASH_COLUMN_NAME, self.npartitions, @@ -651,11 +652,14 @@ def _layer(self) -> dict: self._partitions, self.left_index, ) + dsk[t.key] = t + transfer_keys_left.append(t.ref()) + for i in range(self.right.npartitions): - transfer_keys_right.append((transfer_name_right, i)) - dsk[(transfer_name_right, i)] = ( + t = Task( + (transfer_name_right, i), func, - (self.right._name, i), + TaskRef((self.right._name, i)), self.shuffle_right_on, _HASH_COLUMN_NAME, self.npartitions, @@ -665,22 +669,28 @@ def _layer(self) -> dict: self._partitions, self.right_index, ) + dsk[t.key] = t + transfer_keys_right.append(t.ref()) + + barrier_left = Task( + _barrier_key_left, p2p_barrier, token_left, transfer_keys_left + ) + dsk[barrier_left.key] = barrier_left - dsk[_barrier_key_left] = (p2p_barrier, token_left, transfer_keys_left) - dsk[_barrier_key_right] = ( - p2p_barrier, - token_right, - transfer_keys_right, + barrier_right = Task( + _barrier_key_right, p2p_barrier, token_right, transfer_keys_right ) + dsk[barrier_right.key] = barrier_right for part_out in self._partitions: - dsk[(self._name, part_out)] = ( + t = Task( + (self._name, part_out), merge_unpack, token_left, token_right, part_out, - _barrier_key_left, - _barrier_key_right, + barrier_left.ref(), + barrier_right.ref(), self.how, self.left_on, self.right_on, @@ -690,6 +700,7 @@ def _layer(self) -> dict: self.right_index, self.indicator, ) + dsk[t.key] = t return dsk def _simplify_up(self, parent, dependents): diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py index ca51f6ba..a5795710 100644 --- a/dask_expr/_shuffle.py +++ b/dask_expr/_shuffle.py @@ -9,6 +9,7 @@ import pandas as pd import tlz as toolz from dask import compute +from dask._task_spec import Task, TaskRef from dask.dataframe.core import _concat, make_meta from dask.dataframe.dispatch import is_categorical_dtype from dask.dataframe.shuffle import ( @@ -575,7 +576,6 @@ def _layer(self): token = self._name.split("-")[-1] _barrier_key = barrier_key(ShuffleId(token)) name = "shuffle-transfer-" + token - transfer_keys = list() parts_out = ( self._partitions if self._filtered else list(range(self.npartitions_out)) @@ -585,11 +585,12 @@ def _layer(self): set(self._partitions) if self._filtered else self.npartitions_out ) + transfer_keys = list() for i in range(self.frame.npartitions): - transfer_keys.append((name, i)) - dsk[(name, i)] = ( + t = Task( + (name, i), _shuffle_transfer, - (self.frame._name, i), + TaskRef((self.frame._name, i)), token, i, self.npartitions_out, @@ -599,18 +600,23 @@ def _layer(self): True, True, ) + dsk[t.key] = t + transfer_keys.append(t.ref()) - dsk[_barrier_key] = (p2p_barrier, token, transfer_keys) + barrier = Task(_barrier_key, p2p_barrier, token, transfer_keys) + dsk[barrier.key] = barrier # TODO: Decompose p2p Into transfer/barrier + unpack name = self._name for i, part_out in enumerate(parts_out): - dsk[(name, i)] = ( + t = Task( + (name, i), shuffle_unpack, token, part_out, - _barrier_key, + barrier.ref(), ) + dsk[t.key] = t return dsk