Skip to content

Commit

Permalink
Refactor shuffle method to handle invalid columns (#1091)
Browse files: browse the repository at this point in the history
  • Loading branch information
alex-rakowski authored Jun 24, 2024
1 parent 4659f61 commit cb121cd
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
17 changes: 14 additions & 3 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,8 +896,20 @@ def shuffle(
raise TypeError(
"index must be aligned with the DataFrame to use as shuffle index."
)
elif pd.api.types.is_list_like(on) and not is_dask_collection(on):
on = list(on)
else:
if pd.api.types.is_list_like(on) and not is_dask_collection(on):
on = list(on)
elif isinstance(on, str) or isinstance(on, int):
on = [on]
bad_cols = [
index_col
for index_col in on
if (index_col not in self.columns) and (index_col != self.index.name)
]
if bad_cols:
raise KeyError(
f"Cannot shuffle on {bad_cols}, column(s) not in dataframe to shuffle"
)

if (shuffle_method or get_default_shuffle_method()) == "p2p":
from distributed.shuffle._arrow import check_dtype_support
Expand All @@ -911,7 +923,6 @@ def shuffle(
raise TypeError(
f"p2p requires all column names to be str, found: {unsupported}",
)

# Returned shuffled result
return new_collection(
RearrangeByColumn(
Expand Down
30 changes: 30 additions & 0 deletions dask_expr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,36 @@ def test_task_shuffle_index(npartitions, max_branch, pdf):
assert sorted(df3.compute().values) == list(range(20))


def test_shuffle_str_column_not_in_dataframe(df):
    """Shuffling on a single unknown column name raises a ``KeyError``.

    The error text must start with the "Cannot shuffle on" prefix and
    mention the offending name.
    """
    with pytest.raises(KeyError, match="Cannot shuffle on") as excinfo:
        df.shuffle(on="z")
    # Inspect the message after the ``with`` block: anything placed after the
    # raising call inside the block would never execute.
    message = str(excinfo.value)
    assert "z" in message


def test_shuffle_mixed_list_column_not_in_dataframe(df):
    """A list mixing ``"x"`` and ``"z"`` raises a ``KeyError`` naming only ``"z"``.

    NOTE(review): presumably ``x`` is a column of the ``df`` fixture and ``z``
    is not -- the fixture is defined outside this hunk; confirm.
    """
    with pytest.raises(KeyError, match="Cannot shuffle on") as excinfo:
        df.shuffle(["x", "z"])
    # Inspect the message after the ``with`` block: anything placed after the
    # raising call inside the block would never execute.
    message = str(excinfo.value)
    assert "z" in message
    assert "x" not in message


def test_shuffle_list_column_not_in_dataframe(df):
    """Every unknown column in a list passed to ``shuffle`` is reported.

    Both ``"zz"`` and ``"z"`` are expected to be missing from the ``df``
    fixture, so the ``KeyError`` message must name each of them.
    """
    with pytest.raises(KeyError, match=r"Cannot shuffle on") as excinfo:
        df.shuffle(["zz", "z"])
    message = str(excinfo.value)
    # Match the quoted names as rendered by the list repr in the error
    # message. A bare ``"z" in message`` check is already satisfied by "zz"
    # and would not prove that "z" itself was reported.
    assert "'zz'" in message
    assert "'z'" in message


def test_shuffle_column_columns(df):
    """Smoke test: shuffling on a label taken from ``df.columns`` does not raise."""
    last_column = df.columns[-1]
    df.shuffle(last_column)


def test_shuffle_column_projection(df):
df2 = df.shuffle("x")[["x"]].simplify()

Expand Down

0 comments on commit cb121cd

Please sign in to comment.