diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index ab4d386f05b0..09c6e0351ed3 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -42,7 +42,7 @@ object_store = { workspace = true } pbjson-types = "0.7" # TODO use workspace version prost = "0.13" -substrait = { version = "0.49", features = ["serde"] } +substrait = { version = "0.50", features = ["serde"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a9e411e35ae8..9f98fdace6a0 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -62,7 +62,7 @@ use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion::prelude::JoinType; +use datafusion::prelude::{lit, JoinType}; use datafusion::sql::TableReference; use datafusion::{ error::Result, logical_expr::utils::split_conjunction, prelude::Column, @@ -98,7 +98,7 @@ use substrait::proto::{ sort_field::{SortDirection, SortKind::*}, AggregateFunction, Expression, NamedStruct, Plan, Rel, RelCommon, Type, }; -use substrait::proto::{ExtendedExpression, FunctionArgument, SortField}; +use substrait::proto::{fetch_rel, ExtendedExpression, FunctionArgument, SortField}; use super::state::SubstraitPlanningState; @@ -640,14 +640,27 @@ pub async fn from_substrait_rel( let input = LogicalPlanBuilder::from( from_substrait_rel(state, input, extensions).await?, ); - let offset = fetch.offset as usize; - // -1 means that ALL records should be returned - let count = if fetch.count == -1 { - None - } else { - Some(fetch.count as usize) + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let offset = match &fetch.offset_mode { + Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), + Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => Some( + from_substrait_rex(state, expr, &empty_schema, extensions) + .await?, + ), + None => None, + }; + let count = match &fetch.count_mode { + Some(fetch_rel::CountMode::Count(count)) => { + // -1 means that ALL records should be returned, equivalent to None + (*count != -1).then(|| lit(*count)) + } + Some(fetch_rel::CountMode::CountExpr(expr)) => Some( + from_substrait_rex(state, expr, &empty_schema, extensions) + .await?, + ), + None => None, }; - input.limit(offset, count)?.build() + input.limit_by_expr(offset, count)?.build() } else { not_impl_err!("Fetch without an input is not valid") } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a128b90e6889..375cb734f564 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -22,9 +22,7 @@ use std::sync::Arc; use substrait::proto::expression_reference::ExprType; use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{ - Distinct, FetchType, Like, Partitioning, SkipType, TryCast, WindowFrameUnits, -}; +use datafusion::logical_expr::{Distinct, Like, Partitioning, TryCast, WindowFrameUnits}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -45,7 +43,7 @@ use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, - substrait_err, DFSchemaRef, ToDFSchema, + substrait_err, DFSchema, DFSchemaRef, ToDFSchema, }; #[allow(unused_imports)] use datafusion::logical_expr::expr::{ @@ -69,7 +67,8 @@ use substrait::proto::read_rel::VirtualTable; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; use substrait::proto::{ - rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, + fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, + RelCommon, }; use substrait::{ proto::{ @@ -333,19 +332,31 @@ pub fn to_substrait_rel( } LogicalPlan::Limit(limit) => { let input = to_substrait_rel(limit.input.as_ref(), state, extensions)?; - let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return not_impl_err!("Non-literal limit fetch"); - }; - let SkipType::Literal(skip) = limit.get_skip_type()? else { - return not_impl_err!("Non-literal limit skip"); - }; + let empty_schema = Arc::new(DFSchema::empty()); + let offset_mode = limit + .skip + .as_ref() + .map(|expr| { + to_substrait_rex(state, expr.as_ref(), &empty_schema, 0, extensions) + }) + .transpose()? + .map(Box::new) + .map(fetch_rel::OffsetMode::OffsetExpr); + let count_mode = limit + .fetch + .as_ref() + .map(|expr| { + to_substrait_rex(state, expr.as_ref(), &empty_schema, 0, extensions) + }) + .transpose()? + .map(Box::new) + .map(fetch_rel::CountMode::CountExpr); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, input: Some(input), - offset: skip as i64, - // use -1 to signal that ALL records should be returned - count: fetch.map(|f| f as i64).unwrap_or(-1), + offset_mode, + count_mode, advanced_extension: None, }))), })) diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f836dea03c61..1291bbd6a244 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -240,17 +240,20 @@ async fn select_with_filter_bool_expr() -> Result<()> { #[tokio::test] async fn select_with_limit() -> Result<()> { - roundtrip_fill_na("SELECT * FROM data LIMIT 100").await + roundtrip_fill_na("SELECT * FROM data LIMIT 100").await?; + roundtrip_fill_na("SELECT * FROM data LIMIT 98+100/50").await } #[tokio::test] async fn select_without_limit() -> Result<()> { - roundtrip_fill_na("SELECT * FROM data OFFSET 10").await + roundtrip_fill_na("SELECT * FROM data OFFSET 10").await?; + roundtrip_fill_na("SELECT * FROM data OFFSET 5+7-2").await } #[tokio::test] async fn select_with_limit_offset() -> Result<()> { - roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await + roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await?; + roundtrip("SELECT * FROM data LIMIT 100+100 OFFSET 20/2").await } #[tokio::test]