From 3b05e9a21455daa187d1f7ffa6c6f336589cc71d Mon Sep 17 00:00:00 2001
From: Andrew Gazelka <andrew.gazelka@gmail.com>
Date: Wed, 20 Nov 2024 01:23:05 -0800
Subject: [PATCH] [FEAT]: connect: `df.sort`

---
 .../src/translation/logical_plan.rs           |  4 +-
 .../src/translation/logical_plan/sort.rs      | 83 +++++++++++++++++++
 tests/connect/test_sort.py                    | 16 ++++
 3 files changed, 102 insertions(+), 1 deletion(-)
 create mode 100644 src/daft-connect/src/translation/logical_plan/sort.rs
 create mode 100644 tests/connect/test_sort.py

diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 947e0cd0d3..5e4c8ac5ad 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -3,11 +3,12 @@ use eyre::{bail, Context};
 use spark_connect::{relation::RelType, Relation};
 use tracing::warn;
 
-use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
+use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range, sort::sort};
 
 mod aggregate;
 mod project;
 mod range;
+mod sort;
 
 pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
     if let Some(common) = relation.common {
@@ -24,6 +25,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
         RelType::Aggregate(a) => {
             aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
         }
+        RelType::Sort(s) => sort(*s).wrap_err("Failed to apply sort to logical plan"),
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
 }
diff --git a/src/daft-connect/src/translation/logical_plan/sort.rs b/src/daft-connect/src/translation/logical_plan/sort.rs
new file mode 100644
index 0000000000..15baccbcd8
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/sort.rs
@@ -0,0 +1,83 @@
+use eyre::{bail, WrapErr};
+use spark_connect::expression::{
+    sort_order::{NullOrdering, SortDirection},
+    SortOrder,
+};
+use tracing::warn;
+
+use crate::translation::{logical_plan::LogicalPlanBuilder, to_daft_expr, to_logical_plan};
+
+pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> {
+    let spark_connect::Sort {
+        input,
+        order,
+        is_global,
+    } = sort;
+
+    if let Some(is_global) = is_global {
+        warn!("Ignoring is_global {is_global}; not yet implemented");
+    }
+
+    let Some(input) = input else {
+        bail!("Input is required");
+    };
+
+    let plan = to_logical_plan(*input)?;
+
+    let mut sort_by = Vec::new();
+    let mut descending = Vec::new();
+    let mut nulls_first = Vec::new();
+
+    for o in &order {
+        let SortOrder {
+            child,
+            direction,
+            null_ordering,
+        } = o;
+
+        let Some(child) = child else {
+            bail!("Child is required");
+        };
+
+        let child = to_daft_expr(child)?;
+
+        let direction = SortDirection::try_from(*direction)
+            .wrap_err_with(|| format!("Invalid sort direction: {direction:?}"))?;
+
+        let null_ordering = NullOrdering::try_from(*null_ordering)
+            .wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?;
+
+        // todo(correctness): is this correct?
+        let is_descending = match direction {
+            SortDirection::Unspecified => {
+                bail!("Unspecified sort direction is not yet supported")
+            }
+            SortDirection::Ascending => false,
+            SortDirection::Descending => true,
+        };
+
+        // todo(correctness): is this correct?
+        let tentative_sort_nulls_first = match null_ordering {
+            NullOrdering::SortNullsUnspecified => {
+                bail!("Unspecified null ordering is not yet supported")
+            }
+            NullOrdering::SortNullsFirst => true,
+            NullOrdering::SortNullsLast => false,
+        };
+
+        // https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19
+        let sort_nulls_first = is_descending;
+
+        if sort_nulls_first != tentative_sort_nulls_first {
+            warn!("Ignoring nulls_first {tentative_sort_nulls_first}; not yet implemented");
+        }
+
+        sort_by.push(child);
+        descending.push(is_descending);
+        nulls_first.push(sort_nulls_first);
+    }
+
+    let plan = plan.sort(sort_by, descending, nulls_first)?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_sort.py b/tests/connect/test_sort.py
new file mode 100644
index 0000000000..653510db18
--- /dev/null
+++ b/tests/connect/test_sort.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+
+
+def test_sort(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Sort the DataFrame by 'id' column in descending order
+    df_sorted = df.sort(col("id").desc())
+
+    # Verify the DataFrame is sorted correctly
+    df_pandas = df.toPandas()
+    df_sorted_pandas = df_sorted.toPandas()
+    assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order"