From 07634ddafeaddfe354ff3faf4f14cf8cef433143 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Thu, 21 Nov 2024 09:02:53 -0800 Subject: [PATCH] Update test_sort.py --- .../src/translation/logical_plan.rs | 4 ++- .../src/translation/logical_plan/sort.rs | 17 ++++------- tests/connect/test_sort.py | 29 ++++++++++++++----- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 2bcd8dc061..d2d2039982 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -3,7 +3,9 @@ use eyre::{bail, Context}; use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range, sort::sort}; +use crate::translation::logical_plan::{ + aggregate::aggregate, project::project, range::range, sort::sort, +}; mod aggregate; mod project; diff --git a/src/daft-connect/src/translation/logical_plan/sort.rs b/src/daft-connect/src/translation/logical_plan/sort.rs index 15baccbcd8..a6657fb785 100644 --- a/src/daft-connect/src/translation/logical_plan/sort.rs +++ b/src/daft-connect/src/translation/logical_plan/sort.rs @@ -47,31 +47,24 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result { let null_ordering = NullOrdering::try_from(*null_ordering) .wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?; - // todo(correctness): is this correct? let is_descending = match direction { SortDirection::Unspecified => { - bail!("Unspecified sort direction is not yet supported") + // default to ascending order + false } SortDirection::Ascending => false, SortDirection::Descending => true, }; - // todo(correctness): is this correct? - let tentative_sort_nulls_first = match null_ordering { + let sort_nulls_first = match null_ordering { NullOrdering::SortNullsUnspecified => { - bail!("Unspecified null ordering is not yet supported") + // default: match is_descending + is_descending } NullOrdering::SortNullsFirst => true, NullOrdering::SortNullsLast => false, }; - // https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19 - let sort_nulls_first = is_descending; - - if sort_nulls_first != tentative_sort_nulls_first { - warn!("Ignoring nulls_first {sort_nulls_first}; not yet implemented"); - } - sort_by.push(child); descending.push(is_descending); nulls_first.push(sort_nulls_first); diff --git a/tests/connect/test_sort.py b/tests/connect/test_sort.py index 653510db18..043992c1e3 100644 --- a/tests/connect/test_sort.py +++ b/tests/connect/test_sort.py @@ -3,14 +3,27 @@ from pyspark.sql.functions import col -def test_sort(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) +def test_sort_multiple_columns(spark_session): + # Create DataFrame with two columns using range + df = spark_session.range(4).select(col("id").alias("num"), col("id").alias("letter")) - # Sort the DataFrame by 'id' column in descending order - df_sorted = df.sort(col("id").desc()) + # Sort by multiple columns + df_sorted = df.sort(col("num").asc(), col("letter").desc()) # Verify the DataFrame is sorted correctly - df_pandas = df.toPandas() - df_sorted_pandas = df_sorted.toPandas() - assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order" + actual = df_sorted.collect() + expected = [(0, 0), (1, 1), (2, 2), (3, 3)] + assert [(row.num, row.letter) for row in actual] == expected + + +def test_sort_mixed_order(spark_session): + # Create DataFrame with two columns using range + df = spark_session.range(4).select(col("id").alias("num"), col("id").alias("letter")) + + # Sort with mixed ascending/descending order + df_sorted = df.sort(col("num").desc(), col("letter").asc()) + + # Verify the DataFrame is sorted correctly + actual = df_sorted.collect() + expected = [(3, 3), (2, 2), (1, 1), (0, 0)] + assert [(row.num, row.letter) for row in actual] == expected