diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index aa538820..370a3afc 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -896,8 +896,20 @@ def shuffle(
                 raise TypeError(
                     "index must be aligned with the DataFrame to use as shuffle index."
                 )
-        elif pd.api.types.is_list_like(on) and not is_dask_collection(on):
-            on = list(on)
+        else:
+            if pd.api.types.is_list_like(on) and not is_dask_collection(on):
+                on = list(on)
+            elif isinstance(on, str) or isinstance(on, int):
+                on = [on]
+            bad_cols = [
+                index_col
+                for index_col in on
+                if (index_col not in self.columns) and (index_col != self.index.name)
+            ]
+            if bad_cols:
+                raise KeyError(
+                    f"Cannot shuffle on {bad_cols}, column(s) not in dataframe to shuffle"
+                )
 
         if (shuffle_method or get_default_shuffle_method()) == "p2p":
             from distributed.shuffle._arrow import check_dtype_support
@@ -911,7 +923,6 @@ def shuffle(
                 raise TypeError(
                     f"p2p requires all column names to be str, found: {unsupported}",
                 )
-
         # Returned shuffled result
         return new_collection(
             RearrangeByColumn(
diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py
index 965344de..c3c9c217 100644
--- a/dask_expr/tests/test_shuffle.py
+++ b/dask_expr/tests/test_shuffle.py
@@ -130,5 +130,39 @@ def test_task_shuffle_index(npartitions, max_branch, pdf):
     assert sorted(df3.compute().values) == list(range(20))
 
 
+def test_shuffle_str_column_not_in_dataframe(df):
+    """A single unknown column name passed via ``on=`` raises KeyError."""
+    with pytest.raises(
+        KeyError,
+        match="Cannot shuffle on",
+    ) as excinfo:
+        df.shuffle(on="z")
+    assert "'z'" in str(excinfo.value)
+
+
+def test_shuffle_mixed_list_column_not_in_dataframe(df):
+    """Only the unknown names of a mixed list are reported in the error."""
+    with pytest.raises(
+        KeyError,
+        match="Cannot shuffle on",
+    ) as excinfo:
+        df.shuffle(["x", "z"])
+    assert "'z'" in str(excinfo.value)
+    assert "'x'" not in str(excinfo.value)
+
+
+def test_shuffle_list_column_not_in_dataframe(df):
+    """Every unknown name in the list is reported in the error."""
+    with pytest.raises(KeyError, match=r"Cannot shuffle on") as excinfo:
+        df.shuffle(["zz", "z"])
+    assert "'z'" in str(excinfo.value)
+    assert "'zz'" in str(excinfo.value)
+
+
+def test_shuffle_column_columns(df):
+    """Shuffling on a label taken from ``df.columns`` does not raise."""
+    df.shuffle(df.columns[-1])
+
+
 def test_shuffle_column_projection(df):
     df2 = df.shuffle("x")[["x"]].simplify()