[SPARK-52462] [SQL] Enforce type coercion before children output deduplication in Union #51172

Open · wants to merge 1 commit into master
@@ -1573,7 +1573,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

     case u @ Union(children, _, _)
         // if there are duplicate output columns, give them unique expr ids
-        if children.exists(c => c.output.map(_.exprId).distinct.length < c.output.length) =>
+        if (u.allChildrenCompatible ||
+            !conf.getConf(SQLConf.ENFORCE_TYPE_COERCION_BEFORE_UNION_DEDUPLICATION)) &&
+          children.exists(c => c.output.map(_.exprId).distinct.length < c.output.length) =>
       val newChildren = children.map { c =>
         if (c.output.map(_.exprId).distinct.length < c.output.length) {
           val existingExprIds = mutable.HashSet[ExprId]()
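
In practical terms, the new guard defers the ExprId deduplication until type coercion has made the children's outputs structurally equal, unless the flag is disabled, in which case the old rule ordering applies. Below is a minimal, self-contained Scala sketch of what the deduplication branch does once it fires; `Column` and `UnionDedupSketch` are hypothetical stand-ins for Spark's `Attribute`/`ExprId`/`Alias` machinery, not the actual Analyzer code.

```scala
import scala.collection.mutable
import java.util.concurrent.atomic.AtomicLong

// Hypothetical stand-in for an output attribute with a Spark-style expression id.
final case class Column(name: String, exprId: Long)

object UnionDedupSketch {
  private val nextId = new AtomicLong(100L)

  // Mirrors the rule above: keep the first occurrence of each id, and give any later
  // duplicate a fresh id (Spark does this by wrapping the attribute in an Alias).
  def deduplicate(output: Seq[Column]): Seq[Column] = {
    val existingExprIds = mutable.HashSet[Long]()
    output.map { col =>
      if (existingExprIds.contains(col.exprId)) {
        col.copy(exprId = nextId.getAndIncrement())
      } else {
        existingExprIds += col.exprId
        col
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // "SELECT col2, col2, ..." yields two attributes that share one expression id.
    val child = Seq(Column("col2", 7L), Column("col2", 7L), Column("col3", 8L))
    println(deduplicate(child)) // the second col2 receives a fresh id
  }
}
```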
@@ -603,18 +603,23 @@ case class Union(
   }
 
   override lazy val resolved: Boolean = {
-    // allChildrenCompatible needs to be evaluated after childrenResolved
-    def allChildrenCompatible: Boolean =
-      children.tail.forall( child =>
-        // compare the attribute number with the first child
-        child.output.length == children.head.output.length &&
-        // compare the data types with the first child
-        child.output.zip(children.head.output).forall {
-          case (l, r) => DataType.equalsStructurally(l.dataType, r.dataType, true)
-        })
     children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
   }
 
+  /**
+   * Checks whether the child outputs are compatible by using `DataType.equalsStructurally`. It
+   * does so by comparing the size of each child's output with the size of the first child's
+   * output, and by comparing each child's output data types with those of the first child.
+   *
+   * This method needs to be evaluated after `childrenResolved`.
+   */
+  def allChildrenCompatible: Boolean = childrenResolved && children.tail.forall { child =>
+    child.output.length == children.head.output.length &&
+      child.output.zip(children.head.output).forall {
+        case (l, r) => DataType.equalsStructurally(l.dataType, r.dataType, true)
+      }
+  }
+
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =
     copy(children = newChildren)
 }
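
For intuition: this check only passes once type coercion has aligned the children's output types; in the new test query, the fourth column resolves to different types in the two children until coercion casts both sides to bigint, as the analyzed plans further down show. The following is a self-contained Scala sketch of the same pairwise comparison, with `ColType` and `Child` as hypothetical simplifications of `DataType` and a resolved child plan (Spark itself uses `DataType.equalsStructurally` for the per-column comparison).

```scala
// Self-contained model of the compatibility check: every child must have the same
// number of output columns as the first child, and the column types must match
// pairwise; column names are irrelevant.
sealed trait ColType
case object IntT extends ColType
case object StringT extends ColType
case object BigIntT extends ColType

final case class Child(output: Seq[ColType])

object CompatibilitySketch {
  def allChildrenCompatible(children: Seq[Child]): Boolean =
    children.tail.forall { child =>
      child.output.length == children.head.output.length &&
        child.output.zip(children.head.output).forall { case (l, r) => l == r }
    }

  def main(args: Array[String]): Unit = {
    // Before type coercion: the fourth column's types still disagree.
    val before = Seq(
      Child(Seq(IntT, IntT, IntT, StringT)),
      Child(Seq(IntT, IntT, IntT, IntT)))
    // After type coercion: both sides have been cast to the common type.
    val after = Seq(
      Child(Seq(IntT, IntT, IntT, BigIntT)),
      Child(Seq(IntT, IntT, IntT, BigIntT)))
    println(allChildrenCompatible(before)) // false
    println(allChildrenCompatible(after))  // true
  }
}
```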
@@ -5949,6 +5949,17 @@ object SQLConf {
.createWithDefault(2)
}

  val ENFORCE_TYPE_COERCION_BEFORE_UNION_DEDUPLICATION =
    buildConf("spark.sql.enforceTypeCoercionBeforeUnionDeduplication.enabled")
      .internal()
      .doc(
        "When set to true, we enforce type coercion to run before deduplication of UNION " +
        "children outputs. Otherwise, the order depends on rule ordering."
      )
      .version("4.1.0")
      .booleanConf
      .createWithDefault(true)
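
For completeness, a minimal sketch of toggling the flag back to the legacy behavior. The conf is internal, so it does not appear in the public documentation, but it can still be set like any other SQL conf; the standalone wrapper below is just for illustration.

```scala
import org.apache.spark.sql.SparkSession

object ToggleUnionDedupConf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()

    // Disable the new guarantee to fall back to the legacy rule ordering, where
    // deduplication of duplicate Union outputs may run before type coercion.
    spark.conf.set("spark.sql.enforceTypeCoercionBeforeUnionDeduplication.enabled", "false")

    // The SQL form works as well.
    spark.sql("SET spark.sql.enforceTypeCoercionBeforeUnionDeduplication.enabled=true")

    spark.stop()
  }
}
```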

/**
* Holds information about keys that have been deprecated.
*
@@ -15,6 +15,24 @@ CreateViewCommand `t2`, VALUES (1.0, 1), (2.0, 4) tbl(c1, c2), false, true, Loca
+- LocalRelation [c1#x, c2#x]


-- !query
CREATE TABLE parquetTable (col1 INT, col2 INT, col3 INT, col4 INT) USING parquet
Contributor:

Looking at the changes, I don't think the choice of built-in file format matters here. Why do we test with 3 formats? Shall we just test one?

Contributor Author (@mihailoale-db, Jun 19, 2025):

We could test just one (parquet, for example), but I wanted to cover more in case we missed some weird behavior here. I guess it doesn't hurt to have more tests.

-- !query analysis
CreateDataSourceTableCommand `spark_catalog`.`default`.`parquetTable`, false


-- !query
CREATE TABLE csvTable (col1 INT, col2 INT, col3 INT, col4 INT) USING csv
-- !query analysis
CreateDataSourceTableCommand `spark_catalog`.`default`.`csvTable`, false


-- !query
CREATE TABLE jsonTable (col1 INT, col2 INT, col3 INT, col4 INT) USING json
-- !query analysis
CreateDataSourceTableCommand `spark_catalog`.`default`.`jsonTable`, false


-- !query
SELECT *
FROM (SELECT * FROM t1
@@ -241,6 +259,63 @@ Aggregate [sum(v#x) AS sum(v)#x]
+- LocalRelation [v#x]


-- !query
SELECT col1, col2, col3, NULLIF('','') AS col4
FROM parquetTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM parquetTable
-- !query analysis
Union false, false
:- Project [col1#x, col2#x, col3#x, cast(col4#x as bigint) AS col4#xL]
: +- Project [col1#x, col2#x, col3#x, nullif(, ) AS col4#x]
: +- SubqueryAlias spark_catalog.default.parquettable
: +- Relation spark_catalog.default.parquettable[col1#x,col2#x,col3#x,col4#x] parquet
+- Project [col2#x, col2#x AS col2#x, col3#x, col4#xL]
+- Project [col2#x, col2#x, cast(col3#x as int) AS col3#x, cast(col4#x as bigint) AS col4#xL]
+- Project [col2#x, col2#x, null AS col3#x, col4#x]
+- SubqueryAlias spark_catalog.default.parquettable
+- Relation spark_catalog.default.parquettable[col1#x,col2#x,col3#x,col4#x] parquet


-- !query
SELECT col1, col2, col3, NULLIF('','') AS col4
FROM csvTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM csvTable
-- !query analysis
Union false, false
:- Project [col1#x, col2#x, col3#x, cast(col4#x as bigint) AS col4#xL]
: +- Project [col1#x, col2#x, col3#x, nullif(, ) AS col4#x]
: +- SubqueryAlias spark_catalog.default.csvtable
: +- Relation spark_catalog.default.csvtable[col1#x,col2#x,col3#x,col4#x] csv
+- Project [col2#x, col2#x AS col2#x, col3#x, col4#xL]
+- Project [col2#x, col2#x, cast(col3#x as int) AS col3#x, cast(col4#x as bigint) AS col4#xL]
+- Project [col2#x, col2#x, null AS col3#x, col4#x]
+- SubqueryAlias spark_catalog.default.csvtable
+- Relation spark_catalog.default.csvtable[col1#x,col2#x,col3#x,col4#x] csv


-- !query
SELECT col1, col2, col3, NULLIF('','') AS col4
FROM jsonTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM jsonTable
-- !query analysis
Union false, false
:- Project [col1#x, col2#x, col3#x, cast(col4#x as bigint) AS col4#xL]
: +- Project [col1#x, col2#x, col3#x, nullif(, ) AS col4#x]
: +- SubqueryAlias spark_catalog.default.jsontable
: +- Relation spark_catalog.default.jsontable[col1#x,col2#x,col3#x,col4#x] json
+- Project [col2#x, col2#x AS col2#x, col3#x, col4#xL]
+- Project [col2#x, col2#x, cast(col3#x as int) AS col3#x, cast(col4#x as bigint) AS col4#xL]
+- Project [col2#x, col2#x, null AS col3#x, col4#x]
+- SubqueryAlias spark_catalog.default.jsontable
+- Relation spark_catalog.default.jsontable[col1#x,col2#x,col3#x,col4#x] json


-- !query
DROP VIEW IF EXISTS t1
-- !query analysis
@@ -275,3 +350,24 @@ DropTempViewCommand p2
DROP VIEW IF EXISTS p3
-- !query analysis
DropTempViewCommand p3


-- !query
DROP TABLE IF EXISTS parquetTable
-- !query analysis
DropTable true, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.parquetTable


-- !query
DROP TABLE IF EXISTS csvTable
-- !query analysis
DropTable true, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.csvTable


-- !query
DROP TABLE IF EXISTS jsonTable
-- !query analysis
DropTable true, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.jsonTable
25 changes: 25 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -1,5 +1,8 @@
CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2);
CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2);
CREATE TABLE parquetTable (col1 INT, col2 INT, col3 INT, col4 INT) USING parquet;
CREATE TABLE csvTable (col1 INT, col2 INT, col3 INT, col4 INT) USING csv;
CREATE TABLE jsonTable (col1 INT, col2 INT, col3 INT, col4 INT) USING json;

-- Simple Union
SELECT *
@@ -59,10 +62,32 @@ SELECT SUM(t.v) FROM (
SELECT v + v AS v FROM t3
) t;

-- SPARK-52462: UNION should produce consistent results with different underlying table providers.
SELECT col1, col2, col3, NULLIF('','') AS col4
FROM parquetTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM parquetTable;

SELECT col1, col2, col3, NULLIF('','') AS col4
FROM csvTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM csvTable;

SELECT col1, col2, col3, NULLIF('','') AS col4
FROM jsonTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM jsonTable;

-- Clean-up
DROP VIEW IF EXISTS t1;
DROP VIEW IF EXISTS t2;
DROP VIEW IF EXISTS t3;
DROP VIEW IF EXISTS p1;
DROP VIEW IF EXISTS p2;
DROP VIEW IF EXISTS p3;
DROP TABLE IF EXISTS parquetTable;
DROP TABLE IF EXISTS csvTable;
DROP TABLE IF EXISTS jsonTable;
84 changes: 84 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/union.sql.out
@@ -15,6 +15,30 @@ struct<>



-- !query
CREATE TABLE parquetTable (col1 INT, col2 INT, col3 INT, col4 INT) USING parquet
-- !query schema
struct<>
-- !query output



-- !query
CREATE TABLE csvTable (col1 INT, col2 INT, col3 INT, col4 INT) USING csv
-- !query schema
struct<>
-- !query output



-- !query
CREATE TABLE jsonTable (col1 INT, col2 INT, col3 INT, col4 INT) USING json
-- !query schema
struct<>
-- !query output



-- !query
SELECT *
FROM (SELECT * FROM t1
@@ -200,6 +224,42 @@ struct<sum(v):decimal(21,0)>
3


-- !query
SELECT col1, col2, col3, NULLIF('','') AS col4
FROM parquetTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM parquetTable
-- !query schema
struct<col1:int,col2:int,col3:int,col4:bigint>
-- !query output



-- !query
SELECT col1, col2, col3, NULLIF('','') AS col4
FROM csvTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM csvTable
-- !query schema
struct<col1:int,col2:int,col3:int,col4:bigint>
-- !query output



-- !query
SELECT col1, col2, col3, NULLIF('','') AS col4
FROM jsonTable
UNION ALL
SELECT col2, col2, null AS col3, col4
FROM jsonTable
-- !query schema
struct<col1:int,col2:int,col3:int,col4:bigint>
-- !query output



-- !query
DROP VIEW IF EXISTS t1
-- !query schema
@@ -246,3 +306,27 @@ DROP VIEW IF EXISTS p3
struct<>
-- !query output



-- !query
DROP TABLE IF EXISTS parquetTable
-- !query schema
struct<>
-- !query output



-- !query
DROP TABLE IF EXISTS csvTable
-- !query schema
struct<>
-- !query output



-- !query
DROP TABLE IF EXISTS jsonTable
-- !query schema
struct<>
-- !query output