diff --git a/dask_ml/utils.py b/dask_ml/utils.py index abeaa58a0..83e89de0d 100644 --- a/dask_ml/utils.py +++ b/dask_ml/utils.py @@ -241,8 +241,8 @@ def check_random_state(random_state): raise TypeError("Unexpected type '{}'".format(type(random_state))) -def check_matching_blocks(*arrays): - """Check that the partitioning structure for many arrays matches. +def _check_matching_blocks(*arrays, check_first_dim_only=False): + """Helper function to check blocks match across *arrays. Parameters ---------- @@ -252,18 +252,22 @@ def check_matching_blocks(*arrays): * Dask Array * Dask DataFrame * Dask Series + check_first_dim_only: bool, default false + Whether to only checks the chunks along the first dimension. Only applies + if all the arrays are dask arrays. """ if len(arrays) <= 1: return + slice_to_check = slice(0, 1, 1) if check_first_dim_only else slice(None, None) if all(isinstance(x, da.Array) for x in arrays): # TODO: unknown chunks, ensure blocks match, or just raise (configurable) - chunks = arrays[0].chunks + chunks = arrays[0].chunks[slice_to_check] for array in arrays[1:]: - if array.chunks != chunks: + if array.chunks[slice_to_check] != chunks: raise ValueError( "Mismatched chunks. {} != {}".format(chunks, array.chunks) ) - + # Divisions correspond to the index (first_dim) so no need to use slice_to_check elif all(isinstance(x, (dd.Series, dd.DataFrame)) for x in arrays): divisions = arrays[0].divisions for array in arrays[1:]: @@ -275,6 +279,21 @@ def check_matching_blocks(*arrays): raise ValueError("Unexpected types {}.".format({type(x) for x in arrays})) +def check_matching_blocks(*arrays): + """Check that the partitioning structure for many arrays matches. + + Parameters + ---------- + *arrays : Sequence of array-likes + This includes + + * Dask Array + * Dask DataFrame + * Dask Series + """ + _check_matching_blocks(*arrays, check_first_dim_only=False) + + def check_X_y( X, y, @@ -433,8 +452,8 @@ def _check_y(y, multi_output=False, y_numeric=False): def check_consistent_length(*arrays): - # TODO: check divisions, chunks, etc. - pass + """Check that blocks match for arrays and divisions match for dataframes.""" + _check_matching_blocks(*arrays, check_first_dim_only=True) def check_chunks(n_samples, n_features, chunks=None): diff --git a/tests/test_utils.py b/tests/test_utils.py index e26097748..fc64c73d0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,6 +15,7 @@ assert_estimator_equal, check_array, check_chunks, + check_consistent_length, check_matching_blocks, check_random_state, handle_zeros_in_scale, @@ -234,3 +235,62 @@ def test_matching_blocks_ok(arrays): def test_matching_blocks_raises(arrays): with pytest.raises(ValueError): check_matching_blocks(*arrays) + + +@pytest.mark.parametrize( + "arrays", + [ + ( + da.random.uniform(size=(10, 10), chunks=(10, 10)), + da.random.uniform(size=10, chunks=10), + ), + ( + da.random.uniform(size=(50, 10), chunks=(50, 10)), + da.random.uniform(size=50, chunks=50), + ), + ( + dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), 2) + .reset_index() + .to_dask_array(), + dd.from_pandas(pd.Series([1, 2, 3]), 2).reset_index().to_dask_array(), + ), + ( + dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), 2), + dd.from_pandas(pd.Series([1, 2, 3]), 2), + ), + # Allow known and unknown? + pytest.param( + ( + dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), 2) + .reset_index() + .to_dask_array(), + dd.from_pandas(pd.Series([1, 2, 3]), 2).reset_index(), + ), + marks=pytest.mark.xfail(reason="Known and unknown blocks."), + ), + ], +) +def test_check_consistent_length_ok(arrays): + check_consistent_length(*arrays) + + +@pytest.mark.parametrize( + "arrays", + [ + ( + da.random.uniform(size=(10, 10), chunks=(10, 10)), + da.random.uniform(size=8, chunks=8), + ), + ( + da.random.uniform(size=(100, 10), chunks=(100, 10)), + da.random.uniform(size=50, chunks=50), + ), + ( + dd.from_pandas(pd.DataFrame({"a": [1, 2, 3, 4]}), 4), + dd.from_pandas(pd.Series([1, 2, 3]), 2), + ), + ], +) +def test_check_consistent_length_raises(arrays): + with pytest.raises(ValueError): + check_consistent_length(*arrays)