diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 918ebccbeb70..4a5c156e28ac 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -29,13 +29,12 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ - Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window, + Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window, }; use crate::logical_expr::{ Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition, UserDefinedLogicalNode, }; -use crate::logical_expr::{Limit, Values}; use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; @@ -78,8 +77,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr, - StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery, + SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; @@ -796,8 +795,20 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::SubqueryAlias(_) => children.one()?, - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Limit(limit) => { let input = children.one()?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!( + "Unsupported OFFSET expression: {:?}", + limit.skip + ); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!( + "Unsupported LIMIT expression: {:?}", + limit.fetch + ); + }; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -806,13 +817,13 @@ impl DefaultPhysicalPlanner { // Apply a LocalLimitExec to each partition. The optimizer will also insert // a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec if let Some(fetch) = fetch { - Arc::new(LocalLimitExec::new(input, *fetch + skip)) + Arc::new(LocalLimitExec::new(input, fetch + skip)) } else { input } }; - Arc::new(GlobalLimitExec::new(input, *skip, *fetch)) + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } LogicalPlan::Unnest(Unnest { list_type_columns, diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 2b45d0ed600b..6c4e3c66e397 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -81,7 +81,7 @@ use datafusion::{ runtime_env::RuntimeEnv, }, logical_expr::{ - Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode, + Expr, Extension, LogicalPlan, Sort, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, }, optimizer::{OptimizerConfig, OptimizerRule}, @@ -98,7 +98,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::tree_node::replace_sort_expression; -use datafusion_expr::{Projection, SortExpr}; +use datafusion_expr::{FetchType, Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; @@ -361,28 +361,28 @@ impl OptimizerRule for TopKOptimizerRule { // Note: this code simply looks for the pattern of a Limit followed by a // Sort and replaces it by a TopK node. It does not handle many // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. - if let LogicalPlan::Limit(Limit { - fetch: Some(fetch), - input, + let LogicalPlan::Limit(ref limit) = plan else { + return Ok(Transformed::no(plan)); + }; + let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { + return Ok(Transformed::no(plan)); + }; + + if let LogicalPlan::Sort(Sort { + ref expr, + ref input, .. - }) = &plan + }) = limit.input.as_ref() { - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = **input - { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: *fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - }), - }))); - } + if expr.len() == 1 { + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + }), + }))); } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 47cc947be3ca..d6d5c3e2931c 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -306,11 +306,14 @@ impl NamePreserver { /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan pub fn new(plan: &LogicalPlan) -> Self { Self { - // The schema of Filter, Join and TableScan nodes comes from their inputs rather than - // their expressions, so there is no need to use aliases to preserve expression names. + // The expressions of these plans do not contribute to their output schema, + // so there is no need to preserve expression names to prevent a schema change. use_alias: !matches!( plan, - LogicalPlan::Filter(_) | LogicalPlan::Join(_) | LogicalPlan::TableScan(_) + LogicalPlan::Filter(_) + | LogicalPlan::Join(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Limit(_) ), } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index d2ecd56cdc23..cef05b6f8814 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -40,7 +40,7 @@ use crate::utils::{ find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, + and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, TableProviderFilterPushDown, TableSource, WriteOp, }; @@ -512,9 +512,22 @@ impl LogicalPlanBuilder { /// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows, /// if specified. pub fn limit(self, skip: usize, fetch: Option) -> Result { + let skip_expr = if skip == 0 { + None + } else { + Some(lit(skip as i64)) + }; + let fetch_expr = fetch.map(|f| lit(f as i64)); + self.limit_by_expr(skip_expr, fetch_expr) + } + + /// Limit the number of rows returned + /// + /// Similar to `limit` but uses expressions for `skip` and `fetch` + pub fn limit_by_expr(self, skip: Option, fetch: Option) -> Result { Ok(Self::new(LogicalPlan::Limit(Limit { - skip, - fetch, + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), input: self.plan, }))) } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 26d54803d403..0287846862af 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -549,11 +549,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let mut object = serde_json::json!( { "Node Type": "Limit", - "Skip": skip, } ); + if let Some(s) = skip { + object["Skip"] = s.to_string().into() + }; if let Some(f) = fetch { - object["Fetch"] = serde_json::Value::Number((*f).into()); + object["Fetch"] = f.to_string().into() }; object } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index da44cfb010d7..18ac3f2ab9cb 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -36,9 +36,9 @@ pub use ddl::{ pub use dml::{DmlStatement, WriteOp}; pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, CrossJoin, DescribeTable, - Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, + Distinct, DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, + Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d8dfe7b56e40..e0aae4cb7448 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -49,7 +49,8 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, + UnnestOptions, }; use indexmap::IndexSet; @@ -960,11 +961,21 @@ impl LogicalPlan { .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { - self.assert_no_expressions(expr)?; + let old_expr_len = skip.iter().chain(fetch.iter()).count(); + if old_expr_len != expr.len() { + return internal_err!( + "Invalid number of new Limit expressions: expected {}, got {}", + old_expr_len, + expr.len() + ); + } + // Pop order is same as the order returned by `LogicalPlan::expressions()` + let new_skip = skip.as_ref().and(expr.pop()); + let new_fetch = fetch.as_ref().and(expr.pop()); let input = self.only_input(inputs)?; Ok(LogicalPlan::Limit(Limit { - skip: *skip, - fetch: *fetch, + skip: new_skip.map(Box::new), + fetch: new_fetch.map(Box::new), input: Arc::new(input), })) } @@ -1339,7 +1350,10 @@ impl LogicalPlan { LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), - LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, + LogicalPlan::Limit(limit) => match limit.get_fetch_type() { + Ok(FetchType::Literal(s)) => s, + _ => None, + }, LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), ) => input.max_rows(), @@ -1909,16 +1923,20 @@ impl LogicalPlan { ) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit(limit) => { + // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. + let skip_str = match limit.get_skip_type() { + Ok(SkipType::Literal(n)) => n.to_string(), + _ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()), + }; + let fetch_str = match limit.get_fetch_type() { + Ok(FetchType::Literal(Some(n))) => n.to_string(), + Ok(FetchType::Literal(None)) => "None".to_string(), + _ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()) + }; write!( f, - "Limit: skip={}, fetch={}", - skip, - fetch.map_or_else(|| "None".to_string(), |x| x.to_string()) + "Limit: skip={}, fetch={}", skip_str,fetch_str, ) } LogicalPlan::Subquery(Subquery { .. }) => { @@ -2778,14 +2796,71 @@ impl PartialOrd for Extension { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Limit { /// Number of rows to skip before fetch - pub skip: usize, + pub skip: Option>, /// Maximum number of rows to fetch, /// None means fetching all rows - pub fetch: Option, + pub fetch: Option>, /// The logical plan pub input: Arc, } +/// Different types of skip expression in Limit plan. +pub enum SkipType { + /// The skip expression is a literal value. + Literal(usize), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +/// Different types of fetch expression in Limit plan. +pub enum FetchType { + /// The fetch expression is a literal value. + /// `Literal(None)` means the fetch expression is not provided. + Literal(Option), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +impl Limit { + /// Get the skip type from the limit plan. + pub fn get_skip_type(&self) -> Result { + match self.skip.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(s)) => { + // `skip = NULL` is equivalent to `skip = 0` + let s = s.unwrap_or(0); + if s >= 0 { + Ok(SkipType::Literal(s as usize)) + } else { + plan_err!("OFFSET must be >=0, '{}' was provided", s) + } + } + _ => Ok(SkipType::UnsupportedExpr), + }, + // `skip = None` is equivalent to `skip = 0` + None => Ok(SkipType::Literal(0)), + } + } + + /// Get the fetch type from the limit plan. + pub fn get_fetch_type(&self) -> Result { + match self.fetch.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(Some(s))) => { + if s >= 0 { + Ok(FetchType::Literal(Some(s as usize))) + } else { + plan_err!("LIMIT must be >= 0, '{}' was provided", s) + } + } + Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + _ => Ok(FetchType::UnsupportedExpr), + }, + None => Ok(FetchType::Literal(None)), + } + } +} + /// Removes duplicate rows from the input #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Distinct { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 606868e75abf..b8d7043d7746 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -43,6 +43,7 @@ use crate::{ Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; +use std::ops::Deref; use std::sync::Arc; use crate::expr::{Exists, InSubquery}; @@ -515,12 +516,16 @@ impl LogicalPlan { .chain(select_expr.iter()) .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) .apply_until_stop(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip + .iter() + .chain(fetch.iter()) + .map(|e| e.deref()) + .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) @@ -726,13 +731,32 @@ impl LogicalPlan { schema, })) }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + let skip = skip.map(|e| *e); + let fetch = fetch.map(|e| *e); + map_until_stop_and_collect!( + skip.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }), + fetch, + fetch.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), + input, + }) + }) + } // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e5d280289342..36b72233b5af 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -51,8 +51,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator, - Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -169,6 +170,7 @@ impl<'a> TypeCoercionRewriter<'a> { match plan { LogicalPlan::Join(join) => self.coerce_join(join), LogicalPlan::Union(union) => Self::coerce_union(union), + LogicalPlan::Limit(limit) => Self::coerce_limit(limit), _ => Ok(plan), } } @@ -230,6 +232,37 @@ impl<'a> TypeCoercionRewriter<'a> { })) } + /// Coerce the fetch and skip expression to Int64 type. + fn coerce_limit(limit: Limit) -> Result { + fn coerce_limit_expr( + expr: Expr, + schema: &DFSchema, + expr_name: &str, + ) -> Result { + let dt = expr.get_type(schema)?; + if dt.is_integer() || dt.is_null() { + expr.cast_to(&DataType::Int64, schema) + } else { + plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}") + } + } + + let empty_schema = DFSchema::empty(); + let new_fetch = limit + .fetch + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT")) + .transpose()?; + let new_skip = limit + .skip + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET")) + .transpose()?; + Ok(LogicalPlan::Limit(Limit { + input: limit.input, + fetch: new_fetch.map(Box::new), + skip: new_skip.map(Box::new), + })) + } + fn coerce_join_filter(&self, expr: Expr) -> Result { let expr_type = expr.get_type(self.schema)?; match expr_type { diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 7f918c03e3ac..baf449a045eb 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -31,7 +31,9 @@ use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; -use datafusion_expr::{expr, lit, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{ + expr, lit, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder, +}; use datafusion_physical_expr::execution_props::ExecutionProps; /// This struct rewrite the sub query plan by pull up the correlated @@ -327,16 +329,15 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => Transformed::yes( - if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { + (true, false) => Transformed::yes(match limit.get_fetch_type()? { + FetchType::Literal(Some(0)) => { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::clone(limit.input.schema()), }) - } else { - LogicalPlanBuilder::from((*limit.input).clone()).build()? - }, - ), + } + _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?, + }), _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 25304d4ccafa..829d4c2d2217 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; -use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; +use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; use std::sync::Arc; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is @@ -63,8 +63,13 @@ impl OptimizerRule for EliminateLimit { > { match plan { LogicalPlan::Limit(limit) => { - if let Some(fetch) = limit.fetch { - if fetch == 0 { + // Only supports rewriting for literal fetch + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + if let Some(v) = fetch { + if v == 0 { return Ok(Transformed::yes(LogicalPlan::EmptyRelation( EmptyRelation { produce_one_row: false, @@ -72,11 +77,10 @@ impl OptimizerRule for EliminateLimit { }, ))); } - } else if limit.skip == 0 { - // input also can be Limit, so we should apply again. - return Ok(self - .rewrite(Arc::unwrap_or_clone(limit.input), _config) - .unwrap()); + } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { + // If fetch is `None` and skip is 0, then Limit takes no effect and + // we can remove it. Its input also can be Limit, so we should apply again. + return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); } Ok(Transformed::no(LogicalPlan::Limit(limit))) } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6ed77387046e..bf5ce0531e06 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -27,6 +27,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; use datafusion_common::Result; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; +use datafusion_expr::{lit, FetchType, SkipType}; /// Optimization rule that tries to push down `LIMIT`. /// @@ -56,16 +57,27 @@ impl OptimizerRule for PushDownLimit { return Ok(Transformed::no(plan)); }; - let Limit { skip, fetch, input } = limit; + // Currently only rewrite if skip and fetch are both literals + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; // Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child) = input.as_ref() { - let (skip, fetch) = - combine_limit(limit.skip, limit.fetch, child.skip, child.fetch); - + if let LogicalPlan::Limit(child) = limit.input.as_ref() { + let SkipType::Literal(child_skip) = child.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); let plan = LogicalPlan::Limit(Limit { - skip, - fetch, + skip: Some(Box::new(lit(skip as i64))), + fetch: fetch.map(|f| Box::new(lit(f as i64))), input: Arc::clone(&child.input), }); @@ -75,14 +87,10 @@ impl OptimizerRule for PushDownLimit { // no fetch to push, so return the original plan let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch, - input, - }))); + return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - match Arc::unwrap_or_clone(input) { + match Arc::unwrap_or_clone(limit.input) { LogicalPlan::TableScan(mut scan) => { let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; let new_fetch = scan @@ -162,8 +170,8 @@ impl OptimizerRule for PushDownLimit { .into_iter() .map(|child| { LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), + skip: None, + fetch: Some(Box::new(lit((fetch + skip) as i64))), input: Arc::new(child.clone()), }) }) @@ -203,8 +211,8 @@ impl OptimizerRule for PushDownLimit { /// ``` fn make_limit(skip: usize, fetch: usize, input: Arc) -> LogicalPlan { LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), + skip: Some(Box::new(lit(skip as i64))), + fetch: Some(Box::new(lit(fetch as i64))), input, }) } @@ -224,11 +232,7 @@ fn original_limit( fetch: usize, input: LogicalPlan, ) -> Result> { - Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::no(make_limit(skip, fetch, Arc::new(input)))) } /// Returns the a transformed limit @@ -237,11 +241,7 @@ fn transformed_limit( fetch: usize, input: LogicalPlan, ) -> Result> { - Ok(Transformed::yes(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input)))) } /// Adds a limit to the inputs of a join, if possible diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 4adbb9318d51..73df506397b1 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -62,13 +62,13 @@ use datafusion_expr::{ logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, - EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, - Repartition, Sort, SubqueryAlias, TableScan, Values, Window, + EmptyRelation, Extension, Join, JoinConstraint, Prepare, Projection, Repartition, + Sort, SubqueryAlias, TableScan, Values, Window, }, DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, WindowUDF, }; -use datafusion_expr::{AggregateUDF, ColumnUnnestList, Unnest}; +use datafusion_expr::{AggregateUDF, ColumnUnnestList, FetchType, SkipType, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; use crate::logical_plan::to_proto::serialize_sorts; @@ -1265,17 +1265,28 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Limit(Limit { input, skip, fetch }) => { + LogicalPlan::Limit(limit) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), + limit.input.as_ref(), extension_codec, )?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal skip values", + )); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal fetch values", + )); + }; + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Limit(Box::new( protobuf::LimitNode { input: Some(Box::new(input)), - skip: *skip as i64, + skip: skip as i64, fetch: fetch.unwrap_or(i64::MAX as usize) as i64, }, ))), diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 54945ec43d10..842a1c0cbec1 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,15 +19,14 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, plan_err, Constraints, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Constraints, DFSchema, Result}; use datafusion_expr::expr::Sort; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, + CreateMemoryTable, DdlStatement, Distinct, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderBy, OrderByExpr, Query, SelectInto, - SetExpr, Value, + SetExpr, }; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -85,35 +84,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(input); } - let skip = match skip { - Some(skip_expr) => { - let expr = self.sql_to_expr( - skip_expr.value, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "OFFSET")?; - convert_usize_with_check(n, "OFFSET") - } - _ => Ok(0), - }?; - - let fetch = match fetch { - Some(limit_expr) - if limit_expr != sqlparser::ast::Expr::Value(Value::Null) => - { - let expr = self.sql_to_expr( - limit_expr, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "LIMIT")?; - Some(convert_usize_with_check(n, "LIMIT")?) - } - _ => None, - }; - - LogicalPlanBuilder::from(input).limit(skip, fetch)?.build() + // skip and fetch expressions are not allowed to reference columns from the input plan + let empty_schema = DFSchema::empty(); + + let skip = skip + .map(|o| self.sql_to_expr(o.value, &empty_schema, &mut PlannerContext::new())) + .transpose()?; + let fetch = fetch + .map(|e| self.sql_to_expr(e, &empty_schema, &mut PlannerContext::new())) + .transpose()?; + LogicalPlanBuilder::from(input) + .limit_by_expr(skip, fetch)? + .build() } /// Wrap the logical in a sort @@ -159,50 +141,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -/// Retrieves the constant result of an expression, evaluating it if possible. -/// -/// This function takes an expression and an argument name as input and returns -/// a `Result` indicating either the constant result of the expression or an -/// error if the expression cannot be evaluated. -/// -/// # Arguments -/// -/// * `expr` - An `Expr` representing the expression to evaluate. -/// * `arg_name` - The name of the argument for error messages. -/// -/// # Returns -/// -/// * `Result` - An `Ok` variant containing the constant result if evaluation is successful, -/// or an `Err` variant containing an error message if evaluation fails. -/// -/// tracks a more general solution -fn get_constant_result(expr: &Expr, arg_name: &str) -> Result { - match expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => Ok(*s), - Expr::BinaryExpr(binary_expr) => { - let lhs = get_constant_result(&binary_expr.left, arg_name)?; - let rhs = get_constant_result(&binary_expr.right, arg_name)?; - let res = match binary_expr.op { - Operator::Plus => lhs + rhs, - Operator::Minus => lhs - rhs, - Operator::Multiply => lhs * rhs, - _ => return plan_err!("Unsupported operator for {arg_name} clause"), - }; - Ok(res) - } - _ => plan_err!("Unexpected expression in {arg_name} clause"), - } -} - -/// Converts an `i64` to `usize`, performing a boundary check. -fn convert_usize_with_check(n: i64, arg_name: &str) -> Result { - if n < 0 { - plan_err!("{arg_name} must be >= 0, '{n}' was provided.") - } else { - Ok(n as usize) - } -} - /// Returns the order by expressions from the query. fn to_order_by_exprs(order_by: Option) -> Result> { let Some(OrderBy { exprs, interpolate }) = order_by else { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 037748035fbf..0147a607567b 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -343,20 +343,16 @@ impl Unparser<'_> { relation, ); } - - if let Some(fetch) = limit.fetch { + if let Some(fetch) = &limit.fetch { let Some(query) = query.as_mut() else { return internal_err!( "Limit operator only valid in a statement context." ); }; - query.limit(Some(ast::Expr::Value(ast::Value::Number( - fetch.to_string(), - false, - )))); + query.limit(Some(self.expr_to_sql(fetch)?)); } - if limit.skip > 0 { + if let Some(skip) = &limit.skip { let Some(query) = query.as_mut() else { return internal_err!( "Offset operator only valid in a statement context." @@ -364,10 +360,7 @@ impl Unparser<'_> { }; query.offset(Some(ast::Offset { rows: ast::OffsetRows::None, - value: ast::Expr::Value(ast::Value::Number( - limit.skip.to_string(), - false, - )), + value: self.expr_to_sql(skip)?, })); } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index e7b96199511a..9ed084eec249 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1016,7 +1016,7 @@ fn test_without_offset() { #[test] fn test_with_offset0() { - sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1"); + sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1 OFFSET 0"); } #[test] diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 9910ca8da71f..f2ab4135aaa7 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -581,9 +581,32 @@ select * from (select 1 a union all select 2) b order by a limit 1; 1 # select limit clause invalid -statement error DataFusion error: Error during planning: LIMIT must be >= 0, '\-1' was provided\. +statement error Error during planning: LIMIT must be >= 0, '-1' was provided select * from (select 1 a union all select 2) b order by a limit -1; +statement error Error during planning: OFFSET must be >=0, '-1' was provided +select * from (select 1 a union all select 2) b order by a offset -1; + +statement error Unsupported LIMIT expression +select * from (values(1),(2)) limit (select 1); + +statement error Unsupported OFFSET expression +select * from (values(1),(2)) offset (select 1); + +# disallow non-integer limit/offset +statement error Expected LIMIT to be an integer or null, but got Float64 +select * from (values(1),(2)) limit 0.5; + +statement error Expected OFFSET to be an integer or null, but got Utf8 +select * from (values(1),(2)) offset '1'; + +# test with different integer types +query I +select * from (values (1), (2), (3), (4)) limit 2::int OFFSET 1::tinyint +---- +2 +3 + # select limit with basic arithmetic query I select * from (select 1 a union all select 2) b order by a limit 1+1; @@ -597,13 +620,38 @@ select * from (values (1)) LIMIT 10*100; ---- 1 -# More complex expressions in the limit is not supported yet. -# See issue: https://github.com/apache/datafusion/issues/9821 -statement error DataFusion error: Error during planning: Unsupported operator for LIMIT clause +# select limit with complex arithmetic +query I select * from (values (1)) LIMIT 100/10; +---- +1 -# More complex expressions in the limit is not supported yet. -statement error DataFusion error: Error during planning: Unexpected expression in LIMIT clause +# test constant-folding of LIMIT expr +query I +select * from (values (1), (2), (3), (4)) LIMIT abs(-4) + 4 / -2; -- LIMIT 2 +---- +1 +2 + +# test constant-folding of OFFSET expr +query I +select * from (values (1), (2), (3), (4)) OFFSET abs(-4) + 4 / -2; -- OFFSET 2 +---- +3 +4 + +# test constant-folding of LIMIT and OFFSET +query I +select * from (values (1), (2), (3), (4)) + -- LIMIT 2 + LIMIT abs(-4) + -1 * 2 + -- OFFSET 1 + OFFSET case when 1 < 2 then 1 else 0 end; +---- +2 +3 + +statement error Schema error: No field named column1. select * from (values (1)) LIMIT cast(column1 as tinyint); # select limit clause @@ -613,6 +661,13 @@ select * from (select 1 a union all select 2) b order by a limit null; 1 2 +# offset null takes no effect +query I +select * from (select 1 a union all select 2) b order by a offset null; +---- +1 +2 + # select limit clause query I select * from (select 1 a union all select 2) b order by a limit 0; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 8a8d195507a2..3d5d7cce5673 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -623,8 +623,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let offset = fetch.offset as usize; - // Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX` - let count = if fetch.count as usize == usize::MAX { + // -1 means that ALL records should be returned + let count = if fetch.count == -1 { None } else { Some(fetch.count as usize) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 7504a287c055..bb50c4b9610f 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -24,7 +24,7 @@ use substrait::proto::expression_reference::ExprType; use arrow_buffer::ToByteSlice; use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ - CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, + CrossJoin, Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -326,14 +326,19 @@ pub fn to_substrait_rel( } LogicalPlan::Limit(limit) => { let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; - // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` - let limit_fetch = limit.fetch.unwrap_or(usize::MAX); + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!("Non-literal limit fetch"); + }; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!("Non-literal limit skip"); + }; Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, input: Some(input), - offset: limit.skip as i64, - count: limit_fetch as i64, + offset: skip as i64, + // use -1 to signal that ALL records should be returned + count: fetch.map(|f| f as i64).unwrap_or(-1), advanced_extension: None, }))), }))