Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor shuffle method to handle invalid columns #1091

Merged
merged 8 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,8 +896,24 @@ def shuffle(
raise TypeError(
"index must be aligned with the DataFrame to use as shuffle index."
)
elif pd.api.types.is_list_like(on) and not is_dask_collection(on):
on = list(on)
else:
if pd.api.types.is_list_like(on) and not is_dask_collection(on):
on = list(on)
elif isinstance(on, str):
on = [on]
bad_cols = [
index_col
for index_col in on
if (index_col not in self.columns) and (index_col != self.index.name)
]
if len(bad_cols) == 1:
alex-rakowski marked this conversation as resolved.
Show resolved Hide resolved
raise KeyError(
f"Cannot shuffle on '{bad_cols[0]}', as it is not in target DataFrame columns"
)
elif len(bad_cols) > 1:
raise KeyError(
f"Cannot shuffle on {bad_cols}, as they are not in target DataFrame columns"
)

if (shuffle_method or get_default_shuffle_method()) == "p2p":
from distributed.shuffle._arrow import check_dtype_support
Expand All @@ -912,7 +928,6 @@ def shuffle(
f"p2p requires all column names to be str, found: {unsupported}",
)

# Returned shuffled result
alex-rakowski marked this conversation as resolved.
Show resolved Hide resolved
return new_collection(
RearrangeByColumn(
self,
Expand Down
30 changes: 30 additions & 0 deletions dask_expr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,36 @@ def test_task_shuffle_index(npartitions, max_branch, pdf):
assert sorted(df3.compute().values) == list(range(20))


def test_shuffle_str_column_not_in_dataframe(df):
# ddf = from_pandas(pdf, npartitions=10)
alex-rakowski marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(
KeyError,
match="Cannot shuffle on 'z', as it is not in target DataFrame columns",
):
df.shuffle(on="z") # .compute()


def test_shuffle_mixed_list_column_not_in_dataframe(df):
# not all cols in list are not in dataframe
with pytest.raises(
KeyError,
match="Cannot shuffle on 'z', as it is not in target DataFrame columns",
):
df.shuffle(["x", "z"])


def test_shuffle_list_column_not_in_dataframe(df):
# all cols in list are not in dataframe
alex-rakowski marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(KeyError, match=r"Cannot shuffle on") as excinfo:
df.shuffle(["zz", "z"])
assert "z" in str(excinfo.value)
assert "zz" in str(excinfo.value)


def test_shuffle_column_columns(df):
df.shuffle(df.columns[-1])


def test_shuffle_column_projection(df):
df2 = df.shuffle("x")[["x"]].simplify()

Expand Down
Loading