From 717284887a514126479815b3467cfb79fb82173b Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Thu, 29 Jun 2023 10:40:34 -0700
Subject: [PATCH] First pass at RANK support

---
 dask_sql/physical/rel/logical/window.py | 20 ++++++++++++++++++--
 tests/integration/test_compatibility.py | 13 ++++---------
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/dask_sql/physical/rel/logical/window.py b/dask_sql/physical/rel/logical/window.py
index 331876c49..9aabd3dbc 100644
--- a/dask_sql/physical/rel/logical/window.py
+++ b/dask_sql/physical/rel/logical/window.py
@@ -198,10 +198,22 @@ def map_on_each_group(
     # Calculate the results
     new_columns = {}
     for f, new_column_name, temporary_operand_columns in operations:
-        if f is None:
+        if f == "row_number":
             # This is the row_number operator.
             # We do not need to do any windowing
             column_result = range(1, len(partitioned_group) + 1)
+        elif f == "rank":
+            column_result = partitioned_group.rank(method="min", na_option="top").iloc[
+                :, 0
+            ]
+        elif f == "dense_rank":
+            column_result = partitioned_group.rank(
+                method="dense", na_option="top"
+            ).iloc[:, 0]
+        elif f == "percent_rank":
+            column_result = partitioned_group.rank(
+                method="min", na_option="top", pct=True
+            ).iloc[:, 0]
         else:
             column_result = f(windowed_group, *temporary_operand_columns)
 
@@ -226,7 +238,6 @@ class DaskWindowPlugin(BaseRelPlugin):
     class_name = "Window"
 
     OPERATION_MAPPING = {
-        "row_number": None,  # That is the easiest one: we do not even need to have any windowing. We therefore threat it separately
         "$sum0": SumOperation(),
         "sum": SumOperation(),
         "count": CountOperation(),
@@ -236,6 +247,11 @@ class DaskWindowPlugin(BaseRelPlugin):
         "first_value": FirstValueOperation(),
         "last_value": LastValueOperation(),
         "avg": AvgOperation(),
+        # operations that don't require windowing
+        "row_number": "row_number",
+        "rank": "rank",
+        "dense_rank": "dense_rank",
+        "percent_rank": "percent_rank",
     }
 
     def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
diff --git a/tests/integration/test_compatibility.py b/tests/integration/test_compatibility.py
index b34d64bbb..b39a34021 100644
--- a/tests/integration/test_compatibility.py
+++ b/tests/integration/test_compatibility.py
@@ -586,9 +586,6 @@ def test_window_row_number_partition_by():
     )
 
 
-@pytest.mark.xfail(
-    reason="Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878"
-)
 def test_window_ranks():
     a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50))
     eq_sqlite(
@@ -596,16 +593,14 @@ def test_window_ranks():
     SELECT *,
         RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1,
         DENSE_RANK() OVER (ORDER BY a ASC, b DESC NULLS LAST, c DESC) AS a2,
-        PERCENT_RANK() OVER (ORDER BY a ASC, b ASC NULLS LAST, c) AS a4
+        PERCENT_RANK() OVER (ORDER BY a ASC, b ASC NULLS LAST, c) AS a3
     FROM a
+    ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST
     """,
         a=a,
     )
 
 
-@pytest.mark.xfail(
-    reason="Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878"
-)
 def test_window_ranks_partition_by():
     a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50))
     eq_sqlite(
@@ -624,7 +619,7 @@ def test_window_ranks_partition_by():
 
 
 @pytest.mark.xfail(
-    reason="Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878"
+    reason="Need to implement lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878"
 )
 def test_window_lead_lag():
     a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))
@@ -647,7 +642,7 @@ def test_window_lead_lag():
 
 
 @pytest.mark.xfail(
-    reason="Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878"
+    reason="Need to implement lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878"
 )
 def test_window_lead_lag_partition_by():
     a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))
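
Note (editor's sketch, not part of the patch): the new rank/dense_rank/percent_rank branches call pandas' DataFrame.rank on the partitioned group and take the first column of the result. The snippet below is a minimal, standalone illustration of the pandas semantics the patch assumes; the DataFrame and column name are hypothetical.

import pandas as pd

# A hypothetical partition column with a tie and a missing value.
df = pd.DataFrame({"b": [3.0, 1.0, None, 1.0]})

# RANK(): tied values share the lowest rank of the group; na_option="top"
# ranks missing values first, mirroring NULLS FIRST ordering.
print(df["b"].rank(method="min", na_option="top").tolist())    # [4.0, 2.0, 1.0, 2.0]

# DENSE_RANK(): ties share a rank and no gap is left after the tie.
print(df["b"].rank(method="dense", na_option="top").tolist())  # [3.0, 2.0, 1.0, 2.0]

# PERCENT_RANK(): pct=True reports the min-rank as a fraction rather than a position.
print(df["b"].rank(method="min", na_option="top", pct=True).tolist())

The eq_sqlite tests updated in the patch then exercise these semantics end to end against SQLite's window functions.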