diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 6f30845eb810..2bad683dc1bc 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -706,13 +706,6 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Union(union) => { - if union.inputs.len() != 2 { - return not_impl_err!( - "UNION ALL expected 2 inputs, but found {}", - union.inputs.len() - ); - } - // Covers cases where the UNION is a subquery and the projection is at the top level if select.already_projected() { return self.derive_with_dialect_alias( @@ -729,12 +722,22 @@ impl Unparser<'_> { .map(|input| self.select_to_sql_expr(input, query)) .collect::>>()?; - let union_expr = SetExpr::SetOperation { - op: ast::SetOperator::Union, - set_quantifier: ast::SetQuantifier::All, - left: Box::new(input_exprs[0].clone()), - right: Box::new(input_exprs[1].clone()), - }; + if input_exprs.len() < 2 { + return internal_err!("UNION operator requires at least 2 inputs"); + } + + // Build the union expression tree bottom-up by reversing the order + // note that we are also swapping left and right inputs because of the rev + let union_expr = input_exprs + .into_iter() + .rev() + .reduce(|a, b| SetExpr::SetOperation { + op: ast::SetOperator::Union, + set_quantifier: ast::SetQuantifier::All, + left: Box::new(b), + right: Box::new(a), + }) + .unwrap(); let Some(query) = query.as_mut() else { return internal_err!( diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 24ec7f03deb0..94b4df59ef00 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::*; +use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference}; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, }; use datafusion_expr::{ - col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder, - UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan, + LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, }; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; @@ -42,7 +42,7 @@ use std::{fmt, vec}; use crate::common::{MockContextProvider, MockSessionState}; use datafusion_expr::builder::{ - table_scan_with_filter_and_fetch, table_scan_with_filters, + project, table_scan_with_filter_and_fetch, table_scan_with_filters, }; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_nested::extract::array_element_udf; @@ -1615,3 +1615,51 @@ fn test_unparse_extension_to_sql() -> Result<()> { } Ok(()) } + +#[test] +fn test_unparse_optimized_multi_union() -> Result<()> { + let unparser = Unparser::default(); + + let schema = Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + + let dfschema = Arc::new(DFSchema::try_from(schema)?); + + let empty = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: dfschema.clone(), + }); + + let plan = LogicalPlan::Union(Union { + inputs: vec![ + project(empty.clone(), vec![lit(1).alias("x"), lit("a").alias("y")])?.into(), + project(empty.clone(), vec![lit(1).alias("x"), lit("b").alias("y")])?.into(), + project(empty.clone(), vec![lit(2).alias("x"), lit("a").alias("y")])?.into(), + project(empty.clone(), vec![lit(2).alias("x"), lit("c").alias("y")])?.into(), + ], + schema: dfschema.clone(), + }); + + let sql = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"; + + assert_eq!(unparser.plan_to_sql(&plan)?.to_string(), sql); + + let plan = LogicalPlan::Union(Union { + inputs: vec![project( + empty.clone(), + vec![lit(1).alias("x"), lit("a").alias("y")], + )? + .into()], + schema: dfschema.clone(), + }); + + if let Some(err) = plan_to_sql(&plan).err() { + assert_contains!(err.to_string(), "UNION operator requires at least 2 inputs"); + } else { + panic!("Expected error") + } + + Ok(()) +}