From 241de417e6809661891983655bd42d771b39572d Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Fri, 20 Dec 2024 13:40:20 -0500 Subject: [PATCH 1/3] include FetchRel when producing LogicalPlan from Sort --- .../substrait/src/logical_plan/producer.rs | 35 +++++++++++++++++-- .../tests/cases/roundtrip_logical_plan.rs | 5 +++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 375cb734f564..1f45bc3e83e1 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -368,14 +368,45 @@ pub fn to_substrait_rel( .iter() .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) .collect::<Result<Vec<_>>>()?; - Ok(Box::new(Rel { + + let sort_rel = Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { common: None, input: Some(input), sorts: sort_fields, advanced_extension: None, }))), - })) + }); + + match sort.fetch { + Some(_) => { + let empty_schema = Arc::new(DFSchema::empty()); + let count_mode = sort + .fetch + .map(|amount| { + to_substrait_rex( + state, + &Expr::Literal(ScalarValue::Int64(Some(amount as i64))), + &empty_schema, + 0, + extensions, + ) + }) + .transpose()? 
+ .map(Box::new) + .map(fetch_rel::CountMode::CountExpr); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(sort_rel), + offset_mode: None, + count_mode, + advanced_extension: None, + }))), + })) + } + None => Ok(sort_rel), + } } LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1291bbd6a244..64d901fc371a 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -200,6 +200,11 @@ async fn select_with_filter() -> Result<()> { roundtrip("SELECT * FROM data WHERE a > 1").await } +#[tokio::test] +async fn select_with_filter_sort_limit() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1 ORDER BY b ASC LIMIT 2").await +} + #[tokio::test] async fn select_with_reused_functions() -> Result<()> { let ctx = create_context().await?; From bd4b8a69bbc53de53883655a0a51bda89bde1d9d Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Fri, 20 Dec 2024 16:54:16 -0500 Subject: [PATCH 2/3] add suggested test --- datafusion/substrait/tests/cases/roundtrip_logical_plan.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 64d901fc371a..6c2b76374757 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -205,6 +205,11 @@ async fn select_with_filter_sort_limit() -> Result<()> { roundtrip("SELECT * FROM data WHERE a > 1 ORDER BY b ASC LIMIT 2").await } +#[tokio::test] +async fn select_with_filter_sort_limit_offset() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1 ORDER BY b ASC LIMIT 2 OFFSET 1").await +} + #[tokio::test] async fn 
select_with_reused_functions() -> Result<()> { let ctx = create_context().await?; From 73c4d84052675182913d602a5e9a5fc0e49fd61f Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Sat, 21 Dec 2024 09:39:56 -0500 Subject: [PATCH 3/3] address review feedback --- .../substrait/src/logical_plan/producer.rs | 37 ++++++++----------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1f45bc3e83e1..57c8bc9681cc 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -361,14 +361,14 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; - let sort_fields = sort - .expr + LogicalPlan::Sort(datafusion::logical_expr::Sort { expr, input, fetch }) => { + let sort_fields = expr .iter() - .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) + .map(|e| substrait_sort_field(state, e, input.schema(), extensions)) .collect::<Result<Vec<_>>>()?; + let input = to_substrait_rel(input.as_ref(), state, extensions)?; + let sort_rel = Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { common: None, @@ -378,23 +378,16 @@ pub fn to_substrait_rel( }))), }); - match sort.fetch { - Some(_) => { - let empty_schema = Arc::new(DFSchema::empty()); - let count_mode = sort - .fetch - .map(|amount| { - to_substrait_rex( - state, - &Expr::Literal(ScalarValue::Int64(Some(amount as i64))), - &empty_schema, - 0, - extensions, - ) - }) - .transpose()? 
- .map(Box::new) - .map(fetch_rel::CountMode::CountExpr); + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None,