From 7d32168e767c4d1084bff094a924bb2f0927f99d Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 13:21:34 +0100 Subject: [PATCH 1/8] chore: Refactor shuffle method to handle invalid columns --- dask_expr/_collection.py | 15 ++++++++++++--- dask_expr/tests/test_shuffle.py | 26 ++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index aa538820..1c39182e 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -896,8 +896,18 @@ def shuffle( raise TypeError( "index must be aligned with the DataFrame to use as shuffle index." ) - elif pd.api.types.is_list_like(on) and not is_dask_collection(on): - on = list(on) + else: + if pd.api.types.is_list_like(on) and not is_dask_collection(on): + on = list(on) + elif isinstance(on, str): + on =[on] #this doesn't split the string into characters, it just makes a list with the string as the only element + # Check if 'on' is a valid column + bad_cols = [index_col for index_col in on if (index_col not in self.columns) and (index_col != self.index.name)] + # Adding this check so it collects all bad columns at once rather than requiring multiple runs + if len(bad_cols) == 1: + raise KeyError(f"Cannot shuffle on '{bad_cols[0]}', as it is not in target DataFrame columns") + elif len(bad_cols) > 1: + raise KeyError(f"Cannot shuffle on {bad_cols}, as they are not in target DataFrame columns") if (shuffle_method or get_default_shuffle_method()) == "p2p": from distributed.shuffle._arrow import check_dtype_support @@ -912,7 +922,6 @@ def shuffle( f"p2p requires all column names to be str, found: {unsupported}", ) - # Returned shuffled result return new_collection( RearrangeByColumn( self, diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 965344de..d9cc74ed 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -130,6 +130,32 @@ def 
test_task_shuffle_index(npartitions, max_branch, pdf): assert sorted(df3.compute().values) == list(range(20)) +def test_shuffle_str_column_not_in_dataframe(df): + # ddf = from_pandas(pdf, npartitions=10) + with pytest.raises(KeyError, + match="Cannot shuffle on 'z', as it is not in target DataFrame columns" + ): + df.shuffle(on="z")#.compute() + +def test_shuffle_mixed_list_column_not_in_dataframe(df): + # not all cols in list are not in dataframe + with pytest.raises(KeyError, + match= "Cannot shuffle on 'z', as it is not in target DataFrame columns" + ): + df.shuffle(["x", "z"]) + +def test_shuffle_list_column_not_in_dataframe(df): + # all cols in list are not in dataframe + with pytest.raises(KeyError, + + match=r"Cannot shuffle on") as excinfo: + df.shuffle(["zz", "z"]) + assert "z" in str(excinfo.value) + assert "zz" in str(excinfo.value) + +def test_shuffle_column_columns(df): + df.shuffle(df.columns[-1]) + def test_shuffle_column_projection(df): df2 = df.shuffle("x")[["x"]].simplify() From fbd977bcc7eaea07bc51d2a7c0dd93a12b4ee31d Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 13:25:53 +0100 Subject: [PATCH 2/8] formatting --- dask_expr/_collection.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index 1c39182e..be3aea5f 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -900,14 +900,20 @@ def shuffle( if pd.api.types.is_list_like(on) and not is_dask_collection(on): on = list(on) elif isinstance(on, str): - on =[on] #this doesn't split the string into characters, it just makes a list with the string as the only element - # Check if 'on' is a valid column - bad_cols = [index_col for index_col in on if (index_col not in self.columns) and (index_col != self.index.name)] - # Adding this check so it collects all bad columns at once rather than requiring multiple runs + on = [on] + bad_cols = [ + index_col + for index_col in on + if 
(index_col not in self.columns) and (index_col != self.index.name) + ] if len(bad_cols) == 1: - raise KeyError(f"Cannot shuffle on '{bad_cols[0]}', as it is not in target DataFrame columns") + raise KeyError( + f"Cannot shuffle on '{bad_cols[0]}', as it is not in target DataFrame columns" + ) elif len(bad_cols) > 1: - raise KeyError(f"Cannot shuffle on {bad_cols}, as they are not in target DataFrame columns") + raise KeyError( + f"Cannot shuffle on {bad_cols}, as they are not in target DataFrame columns" + ) if (shuffle_method or get_default_shuffle_method()) == "p2p": from distributed.shuffle._arrow import check_dtype_support From 7d8246d959bedba49431e8063f4d0de18d37c630 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 13:26:31 +0100 Subject: [PATCH 3/8] formatting --- dask_expr/tests/test_shuffle.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index d9cc74ed..f9925cc6 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -132,30 +132,34 @@ def test_task_shuffle_index(npartitions, max_branch, pdf): def test_shuffle_str_column_not_in_dataframe(df): # ddf = from_pandas(pdf, npartitions=10) - with pytest.raises(KeyError, - match="Cannot shuffle on 'z', as it is not in target DataFrame columns" + with pytest.raises( + KeyError, + match="Cannot shuffle on 'z', as it is not in target DataFrame columns", ): - df.shuffle(on="z")#.compute() + df.shuffle(on="z") # .compute() + def test_shuffle_mixed_list_column_not_in_dataframe(df): # not all cols in list are not in dataframe - with pytest.raises(KeyError, - match= "Cannot shuffle on 'z', as it is not in target DataFrame columns" - ): + with pytest.raises( + KeyError, + match="Cannot shuffle on 'z', as it is not in target DataFrame columns", + ): df.shuffle(["x", "z"]) + def test_shuffle_list_column_not_in_dataframe(df): # all cols in list are not in dataframe - 
with pytest.raises(KeyError, - - match=r"Cannot shuffle on") as excinfo: + with pytest.raises(KeyError, match=r"Cannot shuffle on") as excinfo: df.shuffle(["zz", "z"]) assert "z" in str(excinfo.value) assert "zz" in str(excinfo.value) + def test_shuffle_column_columns(df): df.shuffle(df.columns[-1]) + def test_shuffle_column_projection(df): df2 = df.shuffle("x")[["x"]].simplify() From bc3c9d7c0991c97075d555d666e4663cc1709d15 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 13:54:05 +0100 Subject: [PATCH 4/8] Simplifying error message and handling ints --- dask_expr/_collection.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index be3aea5f..e9b96b8e 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -899,20 +899,16 @@ def shuffle( else: if pd.api.types.is_list_like(on) and not is_dask_collection(on): on = list(on) - elif isinstance(on, str): + elif isinstance(on, str) or isinstance(on, int): on = [on] bad_cols = [ index_col for index_col in on if (index_col not in self.columns) and (index_col != self.index.name) ] - if len(bad_cols) == 1: + if bad_cols: raise KeyError( - f"Cannot shuffle on '{bad_cols[0]}', as it is not in target DataFrame columns" - ) - elif len(bad_cols) > 1: - raise KeyError( - f"Cannot shuffle on {bad_cols}, as they are not in target DataFrame columns" + f"Cannot shuffle on {bad_cols}, column(s) not in dataframe to shuffle" ) if (shuffle_method or get_default_shuffle_method()) == "p2p": From dfc414633b2a0bc79fcdf0b06cec13cc8c9436c3 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 13:54:35 +0100 Subject: [PATCH 5/8] handling new error style --- dask_expr/tests/test_shuffle.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index f9925cc6..a13b4ae4 100644 --- a/dask_expr/tests/test_shuffle.py +++
b/dask_expr/tests/test_shuffle.py @@ -134,18 +134,21 @@ def test_shuffle_str_column_not_in_dataframe(df): # ddf = from_pandas(pdf, npartitions=10) with pytest.raises( KeyError, - match="Cannot shuffle on 'z', as it is not in target DataFrame columns", - ): + match="Cannot shuffle on", + ) as execinfo: df.shuffle(on="z") # .compute() + assert "z" in str(execinfo.value) def test_shuffle_mixed_list_column_not_in_dataframe(df): # not all cols in list are not in dataframe with pytest.raises( KeyError, - match="Cannot shuffle on 'z', as it is not in target DataFrame columns", - ): + match="Cannot shuffle on", + ) as execinfo: df.shuffle(["x", "z"]) + assert "z" in str(execinfo.value) + assert "x" not in str(execinfo.value) def test_shuffle_list_column_not_in_dataframe(df): From 5ee5d5c75722556901cfee1cec81a8c79b1934e8 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 15:17:44 +0100 Subject: [PATCH 6/8] clean up comments --- dask_expr/tests/test_shuffle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index a13b4ae4..d3c8b340 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -131,12 +131,11 @@ def test_task_shuffle_index(npartitions, max_branch, pdf): def test_shuffle_str_column_not_in_dataframe(df): - # ddf = from_pandas(pdf, npartitions=10) with pytest.raises( KeyError, match="Cannot shuffle on", ) as execinfo: - df.shuffle(on="z") # .compute() + df.shuffle(on="z") assert "z" in str(execinfo.value) From b1dc24cfc7a1ef08c53d93b5b9327da69a82be39 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 16:09:23 +0100 Subject: [PATCH 7/8] removing comments --- dask_expr/tests/test_shuffle.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index d3c8b340..c3c9c217 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -140,7 
+140,6 @@ def test_shuffle_str_column_not_in_dataframe(df): def test_shuffle_mixed_list_column_not_in_dataframe(df): - # not all cols in list are not in dataframe with pytest.raises( KeyError, match="Cannot shuffle on", @@ -151,7 +150,6 @@ def test_shuffle_mixed_list_column_not_in_dataframe(df): def test_shuffle_list_column_not_in_dataframe(df): - # all cols in list are not in dataframe with pytest.raises(KeyError, match=r"Cannot shuffle on") as excinfo: df.shuffle(["zz", "z"]) assert "z" in str(excinfo.value) From 4d1fa088aa2838a19728c52deae27dfadc6e7cf9 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 24 Jun 2024 16:10:50 +0100 Subject: [PATCH 8/8] adding back in comment --- dask_expr/_collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index e9b96b8e..370a3afc 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -923,7 +923,7 @@ def shuffle( raise TypeError( f"p2p requires all column names to be str, found: {unsupported}", ) - + # Returned shuffled result return new_collection( RearrangeByColumn( self,