Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor shuffle method to handle invalid columns #1091

Merged
merged 8 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,8 +896,20 @@ def shuffle(
raise TypeError(
"index must be aligned with the DataFrame to use as shuffle index."
)
elif pd.api.types.is_list_like(on) and not is_dask_collection(on):
on = list(on)
else:
if pd.api.types.is_list_like(on) and not is_dask_collection(on):
on = list(on)
elif isinstance(on, str) or isinstance(on, int):
on = [on]
bad_cols = [
index_col
for index_col in on
if (index_col not in self.columns) and (index_col != self.index.name)
]
if bad_cols:
raise KeyError(
f"Cannot shuffle on {bad_cols}, column(s) not in dataframe to shuffle"
)

if (shuffle_method or get_default_shuffle_method()) == "p2p":
from distributed.shuffle._arrow import check_dtype_support
Expand All @@ -911,7 +923,6 @@ def shuffle(
raise TypeError(
f"p2p requires all column names to be str, found: {unsupported}",
)

# Returned shuffled result
alex-rakowski marked this conversation as resolved.
Show resolved Hide resolved
return new_collection(
RearrangeByColumn(
Expand Down
30 changes: 30 additions & 0 deletions dask_expr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,36 @@ def test_task_shuffle_index(npartitions, max_branch, pdf):
assert sorted(df3.compute().values) == list(range(20))


def test_shuffle_str_column_not_in_dataframe(df):
with pytest.raises(
KeyError,
match="Cannot shuffle on",
) as execinfo:
df.shuffle(on="z")
assert "z" in str(execinfo.value)


def test_shuffle_mixed_list_column_not_in_dataframe(df):
with pytest.raises(
KeyError,
match="Cannot shuffle on",
) as execinfo:
df.shuffle(["x", "z"])
assert "z" in str(execinfo.value)
assert "x" not in str(execinfo.value)


def test_shuffle_list_column_not_in_dataframe(df):
with pytest.raises(KeyError, match=r"Cannot shuffle on") as excinfo:
df.shuffle(["zz", "z"])
assert "z" in str(excinfo.value)
assert "zz" in str(excinfo.value)


def test_shuffle_column_columns(df):
df.shuffle(df.columns[-1])


def test_shuffle_column_projection(df):
df2 = df.shuffle("x")[["x"]].simplify()

Expand Down
Loading