From 25c057b889c4ce0058a046ec7b5da3603c8c2bdd Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sat, 12 Oct 2024 23:28:26 +0200 Subject: [PATCH 01/54] initial --- .../catalyst/parser/DataTypeAstBuilder.scala | 4 +- .../apache/spark/sql/types/StringType.scala | 10 +++- .../sql/catalyst/analysis/Analyzer.scala | 1 + .../sql/catalyst/analysis/CheckAnalysis.scala | 3 + .../analysis/ResolveImplicitStringTypes.scala | 57 +++++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../ResolveImplicitStringTypesSuite.scala | 49 ++++++++++++++++ 7 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 46fb4a3397c59..e9c75b7ce140d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, ImplicitStringType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -74,7 +74,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => SqlApiConf.get.defaultStringType + case Seq(_) => ImplicitStringType case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 29d48e3d1f47f..13d5162106fda 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.{JString, JValue} -import org.apache.spark.annotation.Stable +import org.apache.spark.annotation.{Evolving, Stable} import org.apache.spark.sql.catalyst.util.CollationFactory /** @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * The id of collation for this StringType. 
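 *
 * Illustrative sketch only (not part of this patch): obtaining and inspecting a
 * collated string type, assuming the UTF8_LCASE collation is registered in
 * CollationFactory:
 * {{{
 *   StringType.collationId                 // 0, the default UTF8_BINARY collation
 *   StringType("UTF8_LCASE").collationId   // id resolved via CollationFactory
 * }}}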
*/ @Stable -class StringType private (val collationId: Int) extends AtomicType with Serializable { +class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if they are byte @@ -106,3 +106,9 @@ case object StringType extends StringType(0) { new StringType(collationId) } } + +/** + * A. + */ +@Evolving +case object ImplicitStringType extends StringType(0) {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5d41c07b47842..ebc6f8b425878 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -339,6 +339,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveRowLevelCommandAssignments :: MoveParameterizedQueriesDown :: BindParameters :: + ResolveImplicitStringTypes :: typeCoercionRules() ++ Seq( ResolveWithCTE, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4720b9dcdfa13..4f8075f5c4bb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -457,6 +457,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "UNBOUND_SQL_PARAMETER", messageParameters = Map("name" -> p.name)) + case e if e.dataType == ImplicitStringType => + throw new RuntimeException("ImplicitStringType is not supported " + e.toJSON) + case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala new file mode 100644 index 0000000000000..fdf1d4387f42e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTableCommand, AlterViewAs, AlterViewSchemaBinding, ColumnDefinition, CreateFunction, CreateTable, CreateTableAsSelect, CreateView, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ImplicitStringType, StringType} + +object ResolveImplicitStringTypes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan match { + case _: CreateTable | _: CreateTableAsSelect | _: AlterTableCommand | + _: CreateView | _: AlterViewAs | _: AlterViewSchemaBinding | _: AlterViewSchemaBinding | + _: CreateFunction => + val res = replaceWith(plan, StringType) + res + + case _ => + val res = replaceWith(plan, SQLConf.get.defaultStringType) + res + } + } + + private def replaceWith(plan: LogicalPlan, newType: StringType): LogicalPlan = { + plan resolveOperators { + case l: LogicalPlan => + l transformExpressions { + case columnDef: ColumnDefinition if columnDef.dataType == ImplicitStringType => + columnDef.copy(dataType = newType) + + case Literal(value, ImplicitStringType) => + Literal(value, newType) + + case cast: Cast if cast.dataType == ImplicitStringType => + cast.copy(dataType = newType) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3ecb680cf6427..ae5cf08b11e05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3266,7 +3266,7 @@ class AstBuilder extends DataTypeAstBuilder * Create a String literal expression. */ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal.create(createString(ctx), conf.defaultStringType) + Literal.create(createString(ctx), ImplicitStringType) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala new file mode 100644 index 0000000000000..3db677a210497 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.sql.collation + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +class ResolveImplicitStringTypesSuite extends QueryTest with SharedSparkSession { + + def withSessionCollation(collation: String)(f: => Unit): Unit = { + sql(s"SET COLLATION $collation") + Utils.tryWithSafeFinally(f) { + sql(s"SET COLLATION UTF8_BINARY") + } + } + + test("initial") { + withTable("tbl") { + sql(s"CREATE TABLE tbl (a STRING) USING parquet") + } + sql(s"SET COLLATION UTF8_LCASE") +// checkAnswer( +// sql(s"SELECT COLLATION('a')"), +// Seq(Row("UTF8_LCASE")) +// ) + + checkAnswer( + sql(s"SELECT 1 WHERE 'a' = cast('A' as STRING)"), + Seq(Row(1)) + ) + } +} From 0c778553ee577452ddcf87dc6012292a3ae90940 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 13 Oct 2024 00:57:59 +0200 Subject: [PATCH 02/54] remove check in CheckAnalysis.scala --- .../apache/spark/sql/types/StringType.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 -- .../analysis/ResolveImplicitStringTypes.scala | 36 +++++++++++++------ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 13d5162106fda..f50a7488d2b4f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -111,4 +111,4 @@ case object StringType extends StringType(0) { * A. */ @Evolving -case object ImplicitStringType extends StringType(0) {} +object ImplicitStringType extends StringType(0) {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4f8075f5c4bb9..4720b9dcdfa13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -457,9 +457,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "UNBOUND_SQL_PARAMETER", messageParameters = Map("name" -> p.name)) - case e if e.dataType == ImplicitStringType => - throw new RuntimeException("ImplicitStringType is not supported " + e.toJSON) - case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala index fdf1d4387f42e..fe8023908ae74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala @@ -19,39 +19,53 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AlterTableCommand, AlterViewAs, AlterViewSchemaBinding, ColumnDefinition, CreateFunction, CreateTable, CreateTableAsSelect, CreateView, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, AlterViewAs, AlterViewSchemaBinding, ColumnDefinition, CreateFunction, CreateTable, CreateTableAsSelect, CreateView, LogicalPlan, ReplaceColumns} import org.apache.spark.sql.catalyst.rules.Rule 
import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ImplicitStringType, StringType} +import org.apache.spark.sql.types.{DataType, ImplicitStringType, StringType} object ResolveImplicitStringTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { plan match { + + // Implicit string type should be resolved to the collation of the object for DDL commands. + // However, this is not implemented yet. So, we will just use UTF8_BINARY for now. case _: CreateTable | _: CreateTableAsSelect | _: AlterTableCommand | + _: AddColumns | _: ReplaceColumns | _: AlterColumn | _: CreateView | _: AlterViewAs | _: AlterViewSchemaBinding | _: AlterViewSchemaBinding | _: CreateFunction => - val res = replaceWith(plan, StringType) - res + plan - case _ => + // Implicit string type should be resolved to the session collation for DML commands. + case _ if SQLConf.get.defaultStringType != StringType => val res = replaceWith(plan, SQLConf.get.defaultStringType) res + + case _ => + plan } } private def replaceWith(plan: LogicalPlan, newType: StringType): LogicalPlan = { plan resolveOperators { - case l: LogicalPlan => - l transformExpressions { - case columnDef: ColumnDefinition if columnDef.dataType == ImplicitStringType => + case operator => + operator transformExpressions { + case columnDef: ColumnDefinition if isImplicitStringType(columnDef.dataType) => columnDef.copy(dataType = newType) + case cast: Cast if isImplicitStringType(cast.dataType) => + cast.copy(dataType = newType) + case Literal(value, ImplicitStringType) => Literal(value, newType) - - case cast: Cast if cast.dataType == ImplicitStringType => - cast.copy(dataType = newType) } } } + + private def isImplicitStringType(dataType: DataType): Boolean = { + dataType match { + case ImplicitStringType => true + case _ => false + } + } } From 09212587901181860f1aaea0af02c0f1e6f81d32 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 13 Oct 2024 06:58:49 +0200 Subject: [PATCH 03/54] some working version --- .../apache/spark/sql/types/StringType.scala | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 + .../analysis/ResolveImplicitStringTypes.scala | 33 ++- .../analysis/TableOutputResolver.scala | 5 + .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 25 ++- .../ResolveImplicitStringTypesSuite.scala | 200 ++++++++++++++++-- 8 files changed, 240 insertions(+), 34 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index f50a7488d2b4f..09d4f30ddd8d1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -111,4 +111,4 @@ case object StringType extends StringType(0) { * A. 
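 * Marker string type assigned by the parser to literals and casts whose collation is
 * not yet known; the analyzer later replaces it with the session collation, or with
 * UTF8_BINARY for DDL commands (see ResolveImplicitStringTypes).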
*/ @Evolving -object ImplicitStringType extends StringType(0) {} +object ImplicitStringType extends StringType(-2) {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ebc6f8b425878..b45c6e3e959fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -314,12 +314,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveProcedures :: BindProcedures :: ResolveTableSpec :: + ResolveImplicitStringTypes :: ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: +// ResolveImplicitStringTypes :: ResolveOutputRelation :: new ResolveDataFrameDropColumns(catalogManager) :: new ResolveSetVariable(catalogManager) :: @@ -339,7 +341,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveRowLevelCommandAssignments :: MoveParameterizedQueriesDown :: BindParameters :: - ResolveImplicitStringTypes :: +// ResolveImplicitStringTypes :: typeCoercionRules() ++ Seq( ResolveWithCTE, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4720b9dcdfa13..d193b03e9c061 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -457,6 +457,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "UNBOUND_SQL_PARAMETER", messageParameters = Map("name" -> p.name)) + case other if other.dataType == ImplicitStringType => + throw new RuntimeException("Found implicit string type in: " + other.toJSON) + case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala index fe8023908ae74..d0a342181b18c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, AlterViewAs, AlterViewSchemaBinding, ColumnDefinition, CreateFunction, CreateTable, CreateTableAsSelect, CreateView, LogicalPlan, ReplaceColumns} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, AlterViewSchemaBinding, ColumnDefinition, CreateFunction, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, ImplicitStringType, StringType} @@ -30,11 +30,20 @@ object ResolveImplicitStringTypes extends Rule[LogicalPlan] { // Implicit string type should be resolved to the collation of the object for DDL commands. // However, this is not implemented yet. So, we will just use UTF8_BINARY for now. 
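+      // For example (mirrors the tests in this patch), with SET COLLATION UTF8_LCASE:
+      //   CREATE TABLE t (c1 STRING)  -- c1 resolves to STRING COLLATE UTF8_BINARY
+      //   SELECT 'a' = 'A'            -- both literals resolve to UTF8_LCASE, so this is true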
- case _: CreateTable | _: CreateTableAsSelect | _: AlterTableCommand | - _: AddColumns | _: ReplaceColumns | _: AlterColumn | + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: AlterViewSchemaBinding | _: AlterViewSchemaBinding | _: CreateFunction => - plan + val res = replaceWith(plan, StringType) + res + + case addCols: AddColumns if hasImplicitStringType(addCols.columnsToAdd) => + addCols.copy(columnsToAdd = replaceColTypes(addCols.columnsToAdd, StringType)) + + case replaceCols: ReplaceColumns if hasImplicitStringType(replaceCols.columnsToAdd) => + replaceCols.copy(columnsToAdd = replaceColTypes(replaceCols.columnsToAdd, StringType)) + + case a: AlterColumn if a.dataType.isDefined && isImplicitStringType(a.dataType.get) => + a.copy(dataType = Some(StringType)) // Implicit string type should be resolved to the session collation for DML commands. case _ if SQLConf.get.defaultStringType != StringType => @@ -42,7 +51,7 @@ object ResolveImplicitStringTypes extends Rule[LogicalPlan] { res case _ => - plan + replaceWith(plan, StringType) } } @@ -62,6 +71,20 @@ object ResolveImplicitStringTypes extends Rule[LogicalPlan] { } } + private def replaceColTypes( + colTypes: Seq[QualifiedColType], + newType: StringType): Seq[QualifiedColType] = { + colTypes.map { + case col if isImplicitStringType(col.dataType) => + col.copy(dataType = newType) + case col => col + } + } + + private def hasImplicitStringType(colTypes: Seq[QualifiedColType]): Boolean = { + colTypes.exists(col => isImplicitStringType(col.dataType)) + } + private def isImplicitStringType(dataType: DataType): Boolean = { dataType match { case ImplicitStringType => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index e2e9bde856781..adf38b219dd9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -564,6 +564,11 @@ object TableOutputResolver extends SQLConfHelper with Logging { conf: SQLConf, colName: String): Expression = { + // this can avoid unneeded casts +// if (expr.dataType == expectedType) { +// return expr +// } + conf.storeAssignmentPolicy match { case StoreAssignmentPolicy.ANSI => val cast = Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ae5cf08b11e05..d559d8567491b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2048,7 +2048,7 @@ class AstBuilder extends DataTypeAstBuilder } val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq) - val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + val table = if (false && conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) } else { unresolvedTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index c8b3f224a3129..6bb423d149546 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -233,16 +233,21 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { planChangeLogger.logRule(rule.ruleName, plan, result) // Run the plan changes validation after each rule. if (Utils.isTesting || enableValidation) { - validatePlanChanges(plan, result) match { - case Some(msg) => - throw new SparkException( - errorClass = "PLAN_VALIDATION_FAILED_RULE_IN_BATCH", - messageParameters = Map( - "rule" -> rule.ruleName, - "batch" -> batch.name, - "reason" -> msg), - cause = null) - case _ => + try { + validatePlanChanges(plan, result) match { + case Some(msg) => + throw new SparkException( + errorClass = "PLAN_VALIDATION_FAILED_RULE_IN_BATCH", + messageParameters = Map( + "rule" -> rule.ruleName, + "batch" -> batch.name, + "reason" -> msg), + cause = null) + case _ => + } + } catch { + case a: Throwable => + throw new RuntimeException("asdf") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala index 3db677a210497..e0941202bf947 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala @@ -18,32 +18,200 @@ package org.apache.spark.sql.collation -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.DatasourceV2SQLBase +import org.apache.spark.sql.types.StringType import org.apache.spark.util.Utils -class ResolveImplicitStringTypesSuite extends QueryTest with SharedSparkSession { +class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { + + def withSessionCollationAndTable(collation: String, tableName: String)(f: => Unit): Unit = { + withTable(tableName) { + withSessionCollation(collation) { + f + } + } + } def withSessionCollation(collation: String)(f: => Unit): Unit = { sql(s"SET COLLATION $collation") - Utils.tryWithSafeFinally(f) { + Utils.tryWithSafeFinally { + f + } { sql(s"SET COLLATION UTF8_BINARY") } } - test("initial") { - withTable("tbl") { - sql(s"CREATE TABLE tbl (a STRING) USING parquet") + def assertTableColumnCollation(table: String, column: String, expectedCollation: String): Unit = { + val colType = spark.table(table).schema(column).dataType + assert(colType === StringType(expectedCollation)) + } + + // region DDL tests + test("create/alter table") { + val tableName = "testcat.tbl" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + // create table with implicit collation + sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") + assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") + + // alter table add column with implicit collation + sql(s"ALTER TABLE $tableName ADD COLUMN c2 STRING") + assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") + + // alter table change column with explicit collation +// sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE STRING COLLATE UTF8_LCASE") +// assertTableColumnCollation(tableName, "c2", "UTF8_LCASE") + + // alter table change column with implicit collation +// sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE STRING") + assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") + } + } + + test("create 
table as select") { + val tableName = "testcat.tbl" + + // literals in select do not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName USING parquet AS SELECT 'a' AS c1") + assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") + } + + // cast in select does not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName USING parquet AS SELECT cast('a' AS STRING) AS c1") + assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") + } + } + + // TODO: does not work +// test("create/alter view") { +// val viewName = "view_test" +// withSessionCollation("UTF8_LCASE") { +// withTempView(viewName) { +// sql(s"CREATE TEMP VIEW $viewName AS SELECT 'a' AS c1") +// checkAnswer( +// sql(s"SELECT COLLATION(c1) FROM $viewName"), +// Seq(Row("UTF8_BINARY")) +// ) +// } +// } +// } + // endregion + + // region DML tests + test("basic") { + withSessionCollation("UTF8_LCASE") { + + // literal without collation + checkAnswer( + sql("SELECT COLLATION('a')"), + Seq(Row("UTF8_LCASE"))) + + // literal with explicit collation + checkAnswer( + sql("SELECT COLLATION('a' collate unicode)"), + Seq(Row("UNICODE"))) + + // cast is aware of session collation + checkAnswer( + sql("SELECT COLLATION(cast('a' as STRING))"), + Seq(Row("UTF8_LCASE"))) + + // expression in where is aware of session collation + checkAnswer( + sql("SELECT 1 WHERE 'a' = 'A'"), + Seq(Row(1))) + + checkAnswer( + sql("SELECT 1 WHERE 'a' = cast('A' as STRING)"), + Seq(Row(1))) + } + } + + test("having group by is aware of session collation") { + val tableName = "testcat.tbl_grp_by" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") + sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") + + // having clause uses session collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName GROUP BY c1 HAVING 'a' = 'A'"), + Seq(Row(1), Row(1))) + + // having clause uses column collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName GROUP BY c1 HAVING c1 = 'A'"), + Seq(Row(1))) } - sql(s"SET COLLATION UTF8_LCASE") -// checkAnswer( -// sql(s"SELECT COLLATION('a')"), -// Seq(Row("UTF8_LCASE")) -// ) + } - checkAnswer( - sql(s"SELECT 1 WHERE 'a' = cast('A' as STRING)"), - Seq(Row(1)) - ) + test("min/max are aware of session collation") { + // scalastyle:off nonascii + val tableName = "testcat.tbl_min_max" + withSessionCollationAndTable("UNICODE", tableName) { + sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") + sql(s"INSERT INTO $tableName VALUES ('1'), ('½')") + + checkAnswer( + sql(s"SELECT MIN(c1) FROM $tableName"), + Seq(Row("1"))) + + checkAnswer( + sql(s"SELECT MAX(c1) FROM $tableName"), + Seq(Row("½"))) + } + // scalastyle:on nonascii } + + test("literals in insert inherit session level collation") { + val tableName = "testcat.tbl_insert" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName (c1 BOOLEAN) USING parquet") + sql(s"INSERT INTO $tableName VALUES ('a' = 'A')") + + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), + Seq(Row(1))) + } + } + +// test("update behavior") { +// val tableName = "testcat.tbl_update" +// withTableAndSessionCollation("UTF8_LCASE", tableName) { +// sql(s"CREATE TABLE $tableName (c1 STRING, c2 INT) USING parquet") +// sql(s"INSERT INTO $tableName VALUES ('a', 0), ('A', 0)") +// +// sql(s"UPDATE $tableName SET c1 = 2 WHERE 'a' = 'A'") +// checkAnswer( +// 
sql(s"SELECT SUM(c2) FROM $tableName"), +// Seq(Row(4))) +// +// sql(s"UPDATE $tableName SET c1 = 2 WHERE c1 = 'A'") +// checkAnswer( +// sql(s"SELECT SUM(c2) FROM $tableName"), +// Seq(Row(2))) +// } +// } + +// test("delete behavior") { +// val tableName = "testcat.tbl_delete" +// withTableAndSessionCollation("UTF8_LCASE", tableName) { +// sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") +// sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") +// +// sql(s"DELETE FROM $tableName WHERE c1 = 'A'") +// checkAnswer( +// sql(s"SELECT COUNT(*) FROM $tableName"), +// Seq(Row(1))) +// +// sql(s"DELETE FROM $tableName WHERE 'a' = 'A'") +// checkAnswer( +// sql(s"SELECT COUNT(*) FROM $tableName"), +// Seq(Row(0))) +// } +// } + // endregion } From b0a213937fe85628ad42a5ebec6b0e7fcab66383 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 13 Oct 2024 07:53:31 +0200 Subject: [PATCH 04/54] add fix for eager eval of inline tables --- .../sql/catalyst/parser/AstBuilder.scala | 29 +++++++++++++++++-- .../ResolveImplicitStringTypesSuite.scala | 5 ++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d559d8567491b..a4a84118bec91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2047,8 +2047,9 @@ class AstBuilder extends DataTypeAstBuilder Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } - val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq) - val table = if (false && conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + val rowSeq = rows.toSeq + val unresolvedTable = UnresolvedInlineTable(aliases, rowSeq) + val table = if (canEagerlyEvaluateInlineTable(rowSeq)) { EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) } else { unresolvedTable @@ -2056,6 +2057,30 @@ class AstBuilder extends DataTypeAstBuilder table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) } + /** + * Determines if the inline table can be eagerly evaluated. Eager evaluation is not allowed + * if the session-level collation is set and there are string literals present in the expressions, + * because the result may depend on the collation of the input. + */ + private def canEagerlyEvaluateInlineTable(rows: Seq[Seq[Expression]]): Boolean = { + val configSet = conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED) + val usingDefaultStringType = SQLConf.get.defaultStringType == StringType + + configSet && + (usingDefaultStringType || !rows.exists(_.exists(containsStringLiteral))) + } + + private def containsStringLiteral(expression: Expression): Boolean = { + def inner(expr: Expression): Boolean = expr match { + case Literal(_, dataType) => + dataType.existsRecursively(_.isInstanceOf[StringType]) + case _ => + expr.children.exists(inner) + } + + expression.resolved && inner(expression) + } + /** * Create an alias (SubqueryAlias) for a join relation. 
This is practically the same as * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala index e0941202bf947..3c3e4bcb8a301 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala @@ -175,6 +175,11 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { checkAnswer( sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), Seq(Row(1))) + + sql(s"INSERT INTO $tableName VALUES (array_contains(array('a'), 'A'))") + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), + Seq(Row(2))) } } From a18d5b9c43491ffd3aacb54aa9015145656c52c9 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 16 Oct 2024 14:36:28 +0200 Subject: [PATCH 05/54] initial working version with tests --- .../catalyst/parser/DataTypeAstBuilder.scala | 4 +- .../org/apache/spark/sql/types/DataType.scala | 26 +++ .../apache/spark/sql/types/StringType.scala | 26 ++- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 - .../analysis/ReplaceDefaultStringType.scala | 84 ++++++++ .../analysis/ResolveImplicitStringTypes.scala | 94 --------- .../analysis/TableOutputResolver.scala | 5 - .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 25 +-- .../apache/spark/sql/util/SchemaUtils.scala | 9 +- ....scala => DefaultCollationTestSuite.scala} | 190 +++++++++++------- 12 files changed, 270 insertions(+), 202 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala rename sql/core/src/test/scala/org/apache/spark/sql/collation/{ResolveImplicitStringTypesSuite.scala => DefaultCollationTestSuite.scala} (50%) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index e9c75b7ce140d..7fccdd46575f9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, ImplicitStringType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DefaultStringType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, 
TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -74,7 +74,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => ImplicitStringType + case Seq(_) => DefaultStringType case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 008c9cd07076c..254646cb5f91e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -105,6 +105,32 @@ abstract class DataType extends AbstractDataType { */ private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + /** + * Recursively applies the provided partial function `f` to transform this DataType tree. + */ + private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = { + this match { + case _ if f.isDefinedAt(this) => + f(this) + + case ArrayType(elementType, containsNull) => + ArrayType(elementType.transformRecursively(f), containsNull) + + case MapType(keyType, valueType, valueContainsNull) => + MapType( + keyType.transformRecursively(f), + valueType.transformRecursively(f), + valueContainsNull) + + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = field.dataType.transformRecursively(f)) + }) + + case _ => this + } + } + final override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 09d4f30ddd8d1..2cac13e17a8ed 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -19,18 +19,21 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.{JString, JValue} -import org.apache.spark.annotation.{Evolving, Stable} +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.internal.SqlApiConf /** * The data type representing `String` values. Please use the singleton `DataTypes.StringType`. * * @since 1.3.0 - * @param collationId + * @param _collationId * The id of collation for this StringType. */ @Stable -class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable { +class StringType private[sql] (private val _collationId: Int) extends AtomicType with Serializable { + + def collationId: Int = _collationId /** * Support for Binary Equality implies that strings are considered equal only if they are byte @@ -108,7 +111,18 @@ case object StringType extends StringType(0) { } /** - * A. + * The result type of literals, column definitions without explicit collation, casts to string + * and some expressions that produce strings but whose output type is not based on the types of its + * children. 
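 * For example (illustrative, using the session collation support from this patch series):
 * {{{
 *   SET COLLATION UTF8_LCASE;
 *   SELECT COLLATION('a');                   -- UTF8_LCASE: the literal is a DefaultStringType
 *   CREATE TABLE t (c STRING) USING parquet; -- c still resolves to UTF8_BINARY
 * }}}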
+ * Idea is to have this behave like a string with the default collation of the session, but that + * we can still differentiate it from a regular string type, because in some places default string + * is not the one with the session collation (e.g. in DDL commands). */ -@Evolving -object ImplicitStringType extends StringType(-2) {} +private[spark] class DefaultStringType extends StringType(0) { + override def collationId: Int = SqlApiConf.get.defaultStringType.collationId +} + +private[spark] case object DefaultStringType extends DefaultStringType { + def apply(): DefaultStringType = new DefaultStringType() +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b45c6e3e959fb..496414e915413 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -314,14 +314,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveProcedures :: BindProcedures :: ResolveTableSpec :: - ResolveImplicitStringTypes :: + ReplaceDefaultStringType :: ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: -// ResolveImplicitStringTypes :: ResolveOutputRelation :: new ResolveDataFrameDropColumns(catalogManager) :: new ResolveSetVariable(catalogManager) :: @@ -341,7 +340,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveRowLevelCommandAssignments :: MoveParameterizedQueriesDown :: BindParameters :: -// ResolveImplicitStringTypes :: typeCoercionRules() ++ Seq( ResolveWithCTE, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d193b03e9c061..4720b9dcdfa13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -457,9 +457,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "UNBOUND_SQL_PARAMETER", messageParameters = Map("name" -> p.name)) - case other if other.dataType == ImplicitStringType => - throw new RuntimeException("Found implicit string type in: " + other.toJSON) - case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala new file mode 100644 index 0000000000000..800513ebd7f63 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, ColumnDefinition, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} +import org.apache.spark.sql.util.SchemaUtils + +/** + * Replaces default string types in DDL commands. + * DDL commands should have a default collation based on the object's collation, + * however, this is not implemented yet. So, we will just use UTF8_BINARY for now. + */ +object ReplaceDefaultStringType extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan match { + case _: V2CreateTablePlan => + transform(plan, StringType) + + case addCols: AddColumns => + addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, StringType)) + + case replaceCols: ReplaceColumns => + replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, StringType)) + + case a: AlterColumn if a.dataType.isDefined => + a.copy(dataType = Some(StringType)) + + case _ => + plan + } + } + + private def transform(plan: LogicalPlan, newType: StringType): LogicalPlan = { + plan resolveOperatorsUp { + case operator => + operator transformExpressionsUp { + case columnDef: ColumnDefinition + if SchemaUtils.hasDefaultStringType(columnDef.dataType) => + columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) + + case cast: Cast if SchemaUtils.hasDefaultStringType(cast.dataType) => + cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) + + case Literal(value, dt) if SchemaUtils.hasDefaultStringType(dt) => + Literal(value, replaceDefaultStringType(dt, newType)) + } + } + } + + private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { + dataType.transformRecursively { + case st: StringType if st.isInstanceOf[DefaultStringType] => newType + } + } + + private def replaceColumnTypes( + colTypes: Seq[QualifiedColType], + newType: StringType): Seq[QualifiedColType] = { + colTypes.map { + case colWithDefault if SchemaUtils.hasDefaultStringType(colWithDefault.dataType) => + colWithDefault.copy(dataType = newType) + case col => col + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala deleted file mode 100644 index d0a342181b18c..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveImplicitStringTypes.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, AlterViewSchemaBinding, ColumnDefinition, CreateFunction, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, ImplicitStringType, StringType} - -object ResolveImplicitStringTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan match { - - // Implicit string type should be resolved to the collation of the object for DDL commands. - // However, this is not implemented yet. So, we will just use UTF8_BINARY for now. - case _: V2CreateTablePlan | - _: CreateView | _: AlterViewAs | _: AlterViewSchemaBinding | _: AlterViewSchemaBinding | - _: CreateFunction => - val res = replaceWith(plan, StringType) - res - - case addCols: AddColumns if hasImplicitStringType(addCols.columnsToAdd) => - addCols.copy(columnsToAdd = replaceColTypes(addCols.columnsToAdd, StringType)) - - case replaceCols: ReplaceColumns if hasImplicitStringType(replaceCols.columnsToAdd) => - replaceCols.copy(columnsToAdd = replaceColTypes(replaceCols.columnsToAdd, StringType)) - - case a: AlterColumn if a.dataType.isDefined && isImplicitStringType(a.dataType.get) => - a.copy(dataType = Some(StringType)) - - // Implicit string type should be resolved to the session collation for DML commands. 
- case _ if SQLConf.get.defaultStringType != StringType => - val res = replaceWith(plan, SQLConf.get.defaultStringType) - res - - case _ => - replaceWith(plan, StringType) - } - } - - private def replaceWith(plan: LogicalPlan, newType: StringType): LogicalPlan = { - plan resolveOperators { - case operator => - operator transformExpressions { - case columnDef: ColumnDefinition if isImplicitStringType(columnDef.dataType) => - columnDef.copy(dataType = newType) - - case cast: Cast if isImplicitStringType(cast.dataType) => - cast.copy(dataType = newType) - - case Literal(value, ImplicitStringType) => - Literal(value, newType) - } - } - } - - private def replaceColTypes( - colTypes: Seq[QualifiedColType], - newType: StringType): Seq[QualifiedColType] = { - colTypes.map { - case col if isImplicitStringType(col.dataType) => - col.copy(dataType = newType) - case col => col - } - } - - private def hasImplicitStringType(colTypes: Seq[QualifiedColType]): Boolean = { - colTypes.exists(col => isImplicitStringType(col.dataType)) - } - - private def isImplicitStringType(dataType: DataType): Boolean = { - dataType match { - case ImplicitStringType => true - case _ => false - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index adf38b219dd9b..e2e9bde856781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -564,11 +564,6 @@ object TableOutputResolver extends SQLConfHelper with Logging { conf: SQLConf, colName: String): Expression = { - // this can avoid unneeded casts -// if (expr.dataType == expectedType) { -// return expr -// } - conf.storeAssignmentPolicy match { case StoreAssignmentPolicy.ANSI => val cast = Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a4a84118bec91..1da5ea75d04ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3291,7 +3291,7 @@ class AstBuilder extends DataTypeAstBuilder * Create a String literal expression. */ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal.create(createString(ctx), ImplicitStringType) + Literal.create(createString(ctx), DefaultStringType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6bb423d149546..c8b3f224a3129 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -233,21 +233,16 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { planChangeLogger.logRule(rule.ruleName, plan, result) // Run the plan changes validation after each rule. 
if (Utils.isTesting || enableValidation) { - try { - validatePlanChanges(plan, result) match { - case Some(msg) => - throw new SparkException( - errorClass = "PLAN_VALIDATION_FAILED_RULE_IN_BATCH", - messageParameters = Map( - "rule" -> rule.ruleName, - "batch" -> batch.name, - "reason" -> msg), - cause = null) - case _ => - } - } catch { - case a: Throwable => - throw new RuntimeException("asdf") + validatePlanChanges(plan, result) match { + case Some(msg) => + throw new SparkException( + errorClass = "PLAN_VALIDATION_FAILED_RULE_IN_BATCH", + messageParameters = Map( + "rule" -> rule.ruleName, + "batch" -> batch.name, + "reason" -> msg), + cause = null) + case _ => } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 1e0bac331dc75..277ae0a92bdf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, DefaultStringType, MapType, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkSchemaUtils @@ -304,6 +304,13 @@ private[spark] object SchemaUtils { } } + /** + * Checks if a given data type has a default string type. + */ + def hasDefaultStringType(dataType: DataType): Boolean = { + dataType.existsRecursively(_.isInstanceOf[DefaultStringType]) + } + /** * Replaces any collated string type with non collated StringType * recursively in the given data type. 
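  // Sketch (illustrative only; names from this patch): plain STRING parses to the
  // default string type here, so a schema can be checked and rewritten like:
  //   val schema = StructType.fromDDL("c1 STRING, c2 ARRAY<STRING>")
  //   if (SchemaUtils.hasDefaultStringType(schema)) {
  //     schema.transformRecursively { case _: DefaultStringType => StringType("UTF8_LCASE") }
  //   }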
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala similarity index 50% rename from sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 3c3e4bcb8a301..04317d406bfd2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/ResolveImplicitStringTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.collation import org.apache.spark.sql.Row import org.apache.spark.sql.connector.DatasourceV2SQLBase +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.StringType -import org.apache.spark.util.Utils -class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { +class DefaultCollationTestSuite extends DatasourceV2SQLBase { + + val dataSource: String = "parquet" def withSessionCollationAndTable(collation: String, tableName: String)(f: => Unit): Unit = { withTable(tableName) { @@ -34,11 +36,8 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { } def withSessionCollation(collation: String)(f: => Unit): Unit = { - sql(s"SET COLLATION $collation") - Utils.tryWithSafeFinally { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { f - } { - sql(s"SET COLLATION UTF8_BINARY") } } @@ -48,24 +47,34 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { } // region DDL tests + test("create/alter table") { val tableName = "testcat.tbl" withSessionCollationAndTable("UTF8_LCASE", tableName) { // create table with implicit collation - sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") + sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") // alter table add column with implicit collation sql(s"ALTER TABLE $tableName ADD COLUMN c2 STRING") assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") - // alter table change column with explicit collation -// sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE STRING COLLATE UTF8_LCASE") -// assertTableColumnCollation(tableName, "c2", "UTF8_LCASE") + // TODO: alter table change column with explicit collation when we add alter support - // alter table change column with implicit collation -// sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE STRING") - assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") + // TODO: alter table change column with default collation when we add alter support + } + } + + test("create table with explicit collation") { + val tableName = "testcat.tbl_explicit_collation" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(tableName, "c1", "UTF8_LCASE") + } + + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UNICODE) USING $dataSource") + assertTableColumnCollation(tableName, "c1", "UNICODE") } } @@ -74,34 +83,48 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { // literals in select do not pick up session collation withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName USING parquet AS SELECT 'a' AS c1") + sql(s"CREATE TABLE $tableName USING $dataSource AS SELECT 'a' 
AS c1") assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") } + // literals in inline table do not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s""" + |CREATE TABLE $tableName USING $dataSource AS + |SELECT c1, c1 = 'A' as c2 FROM VALUES ('a'), ('A') AS vals(c1) + |""".stripMargin) + assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName WHERE c2"), + Seq(Row(1))) + } + // cast in select does not pick up session collation withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName USING parquet AS SELECT cast('a' AS STRING) AS c1") + sql(s"CREATE TABLE $tableName USING $dataSource AS SELECT cast('a' AS STRING) AS c1") assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") } } - // TODO: does not work -// test("create/alter view") { -// val viewName = "view_test" -// withSessionCollation("UTF8_LCASE") { -// withTempView(viewName) { -// sql(s"CREATE TEMP VIEW $viewName AS SELECT 'a' AS c1") -// checkAnswer( -// sql(s"SELECT COLLATION(c1) FROM $viewName"), -// Seq(Row("UTF8_BINARY")) -// ) -// } -// } -// } + test("add column") { + val tableName = "testcat.tbl_add_col" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(tableName, "c1", "UTF8_LCASE") + + sql(s"ALTER TABLE $tableName ADD COLUMN c2 STRING") + assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") + + sql(s"ALTER TABLE $tableName ADD COLUMN c3 STRING COLLATE UNICODE") + assertTableColumnCollation(tableName, "c3", "UNICODE") + } + } + // endregion // region DML tests - test("basic") { + + test("literals with default collation") { withSessionCollation("UTF8_LCASE") { // literal without collation @@ -109,16 +132,62 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { sql("SELECT COLLATION('a')"), Seq(Row("UTF8_LCASE"))) - // literal with explicit collation + checkAnswer( + sql("SELECT COLLATION(map('a', 'b')['a'])"), + Seq(Row("UTF8_LCASE"))) + + checkAnswer( + sql("SELECT COLLATION(array('a')[0])"), + Seq(Row("UTF8_LCASE"))) + + checkAnswer( + sql("SELECT COLLATION(struct('a' as c)['c'])"), + Seq(Row("UTF8_LCASE"))) + } + } + + test("literals with explicit collation") { + withSessionCollation("UTF8_LCASE") { checkAnswer( sql("SELECT COLLATION('a' collate unicode)"), Seq(Row("UNICODE"))) - // cast is aware of session collation checkAnswer( - sql("SELECT COLLATION(cast('a' as STRING))"), + sql("SELECT COLLATION(map('a', 'b' collate unicode)['a'])"), + Seq(Row("UNICODE"))) + + checkAnswer( + sql("SELECT COLLATION(array('a' collate unicode)[0])"), + Seq(Row("UNICODE"))) + + checkAnswer( + sql("SELECT COLLATION(struct('a' collate unicode as c)['c'])"), + Seq(Row("UNICODE"))) + } + } + + test("cast is aware of session collation") { + withSessionCollation("UTF8_LCASE") { + checkAnswer( + sql("SELECT COLLATION(cast('a' collate unicode as STRING))"), Seq(Row("UTF8_LCASE"))) + checkAnswer( + sql("SELECT COLLATION(cast(map('a', 'b' collate unicode) as MAP)['a'])"), + Seq(Row("UTF8_LCASE"))) + + checkAnswer( + sql("SELECT COLLATION(cast(array('a' collate unicode) as ARRAY)[0])"), + Seq(Row("UTF8_LCASE"))) + + checkAnswer( + sql("SELECT COLLATION(cast(struct('a' collate unicode as c) as STRUCT)['c'])"), + Seq(Row("UTF8_LCASE"))) + } + } + + test("expressions in where are aware of session collation") { + withSessionCollation("UTF8_LCASE") { // expression in where is 
aware of session collation checkAnswer( sql("SELECT 1 WHERE 'a' = 'A'"), @@ -133,15 +202,15 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { test("having group by is aware of session collation") { val tableName = "testcat.tbl_grp_by" withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") + sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") - // having clause uses session collation + // having clause uses session (default) collation checkAnswer( sql(s"SELECT COUNT(*) FROM $tableName GROUP BY c1 HAVING 'a' = 'A'"), Seq(Row(1), Row(1))) - // having clause uses column collation + // having clause uses column (implicit) collation checkAnswer( sql(s"SELECT COUNT(*) FROM $tableName GROUP BY c1 HAVING c1 = 'A'"), Seq(Row(1))) @@ -152,7 +221,7 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { // scalastyle:off nonascii val tableName = "testcat.tbl_min_max" withSessionCollationAndTable("UNICODE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") + sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('1'), ('½')") checkAnswer( @@ -169,7 +238,7 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { test("literals in insert inherit session level collation") { val tableName = "testcat.tbl_insert" withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 BOOLEAN) USING parquet") + sql(s"CREATE TABLE $tableName (c1 BOOLEAN) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('a' = 'A')") checkAnswer( @@ -183,40 +252,17 @@ class ResolveImplicitStringTypesSuite extends DatasourceV2SQLBase { } } -// test("update behavior") { -// val tableName = "testcat.tbl_update" -// withTableAndSessionCollation("UTF8_LCASE", tableName) { -// sql(s"CREATE TABLE $tableName (c1 STRING, c2 INT) USING parquet") -// sql(s"INSERT INTO $tableName VALUES ('a', 0), ('A', 0)") -// -// sql(s"UPDATE $tableName SET c1 = 2 WHERE 'a' = 'A'") -// checkAnswer( -// sql(s"SELECT SUM(c2) FROM $tableName"), -// Seq(Row(4))) -// -// sql(s"UPDATE $tableName SET c1 = 2 WHERE c1 = 'A'") -// checkAnswer( -// sql(s"SELECT SUM(c2) FROM $tableName"), -// Seq(Row(2))) -// } -// } - -// test("delete behavior") { -// val tableName = "testcat.tbl_delete" -// withTableAndSessionCollation("UTF8_LCASE", tableName) { -// sql(s"CREATE TABLE $tableName (c1 STRING) USING parquet") -// sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") -// -// sql(s"DELETE FROM $tableName WHERE c1 = 'A'") -// checkAnswer( -// sql(s"SELECT COUNT(*) FROM $tableName"), -// Seq(Row(1))) -// -// sql(s"DELETE FROM $tableName WHERE 'a' = 'A'") -// checkAnswer( -// sql(s"SELECT COUNT(*) FROM $tableName"), -// Seq(Row(0))) -// } -// } + test("delete behavior") { + val tableName = "testcat.tbl_delete" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") + + sql(s"DELETE FROM $tableName WHERE 'a' = 'A'") + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName"), + Seq(Row(0))) + } + } // endregion } From 14130fa9feb4564400cc1f56623813c4a176667e Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 16 Oct 2024 19:22:32 +0200 Subject: [PATCH 06/54] change collation id for default type --- .../sql/catalyst/util/CollationFactory.java | 4 +++ .../apache/spark/sql/types/StringType.scala | 2 
+- .../analysis/ReplaceDefaultStringType.scala | 10 +++---- .../sql/catalyst/parser/AstBuilder.scala | 4 +-- .../collation/DefaultCollationTestSuite.scala | 29 ++++++++++++------- 5 files changed, 30 insertions(+), 19 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 113c5f866fd88..cc8b0be911b16 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -310,6 +310,8 @@ protected enum SpaceTrimming { private static final int INDETERMINATE_COLLATION_ID = -1; + private static final int DEFAULT_COLLATION_ID = -2; + /** * Thread-safe cache mapping collation IDs to corresponding `Collation` instances. * We add entries to this cache lazily as new `Collation` instances are requested. @@ -1041,6 +1043,8 @@ public CollationIdentifier identifier() { Collation.CollationSpecICU.UNICODE_CI_COLLATION_ID; public static final int INDETERMINATE_COLLATION_ID = Collation.CollationSpec.INDETERMINATE_COLLATION_ID; + public static final int DEFAULT_COLLATION_ID = + Collation.CollationSpec.DEFAULT_COLLATION_ID; /** * Returns a StringSearch object for the given pattern and target strings, under collation diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 2cac13e17a8ed..319cc134090c7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -101,7 +101,7 @@ class StringType private[sql] (private val _collationId: Int) extends AtomicType * @since 1.3.0 */ @Stable -case object StringType extends StringType(0) { +case object StringType extends StringType(CollationFactory.DEFAULT_COLLATION_ID) { private[spark] def apply(collationId: Int): StringType = new StringType(collationId) def apply(collation: String): StringType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala index 800513ebd7f63..0204efd8b2301 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.SchemaUtils */ object ReplaceDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - plan match { + plan resolveOperatorsUp { case _: V2CreateTablePlan => transform(plan, StringType) @@ -41,11 +41,9 @@ object ReplaceDefaultStringType extends Rule[LogicalPlan] { case replaceCols: ReplaceColumns => replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, StringType)) - case a: AlterColumn if a.dataType.isDefined => - a.copy(dataType = Some(StringType)) - - case _ => - plan + case a: AlterColumn + if a.dataType.isDefined && SchemaUtils.hasDefaultStringType(a.dataType.get) => + a.copy(dataType = Some(replaceDefaultStringType(a.dataType.get, StringType))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 
1da5ea75d04ec..70d2cca58cddc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2064,10 +2064,10 @@ class AstBuilder extends DataTypeAstBuilder */ private def canEagerlyEvaluateInlineTable(rows: Seq[Seq[Expression]]): Boolean = { val configSet = conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED) - val usingDefaultStringType = SQLConf.get.defaultStringType == StringType + val sessionCollationUnchanged = DefaultStringType == StringType configSet && - (usingDefaultStringType || !rows.exists(_.exists(containsStringLiteral))) + (sessionCollationUnchanged || !rows.exists(_.exists(containsStringLiteral))) } private def containsStringLiteral(expression: Expression): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 04317d406bfd2..bc7f5b0dd72dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.StringType class DefaultCollationTestSuite extends DatasourceV2SQLBase { - + val dataSource: String = "parquet" def withSessionCollationAndTable(collation: String, tableName: String)(f: => Unit): Unit = { @@ -49,7 +49,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { // region DDL tests test("create/alter table") { - val tableName = "testcat.tbl" + val tableName = "tbl" withSessionCollationAndTable("UTF8_LCASE", tableName) { // create table with implicit collation sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") @@ -59,14 +59,16 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { sql(s"ALTER TABLE $tableName ADD COLUMN c2 STRING") assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") - // TODO: alter table change column with explicit collation when we add alter support + sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE STRING COLLATE UNICODE") + assertTableColumnCollation(tableName, "c2", "UNICODE") - // TODO: alter table change column with default collation when we add alter support + sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE STRING") + assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") } } test("create table with explicit collation") { - val tableName = "testcat.tbl_explicit_collation" + val tableName = "tbl_explicit_collation" withSessionCollationAndTable("UTF8_LCASE", tableName) { sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") assertTableColumnCollation(tableName, "c1", "UTF8_LCASE") @@ -79,7 +81,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } test("create table as select") { - val tableName = "testcat.tbl" + val tableName = "tbl" // literals in select do not pick up session collation withSessionCollationAndTable("UTF8_LCASE", tableName) { @@ -87,6 +89,13 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") } + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s"CREATE TABLE $tableName USING $dataSource AS SELECT 'a' = 'A' AS c1") + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), + Seq(Row(1))) + } + // literals in inline table do not pick up 
session collation withSessionCollationAndTable("UTF8_LCASE", tableName) { sql(s""" @@ -107,7 +116,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } test("add column") { - val tableName = "testcat.tbl_add_col" + val tableName = "tbl_add_col" withSessionCollationAndTable("UTF8_LCASE", tableName) { sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") assertTableColumnCollation(tableName, "c1", "UTF8_LCASE") @@ -200,7 +209,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } test("having group by is aware of session collation") { - val tableName = "testcat.tbl_grp_by" + val tableName = "tbl_grp_by" withSessionCollationAndTable("UTF8_LCASE", tableName) { sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") @@ -219,7 +228,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { test("min/max are aware of session collation") { // scalastyle:off nonascii - val tableName = "testcat.tbl_min_max" + val tableName = "tbl_min_max" withSessionCollationAndTable("UNICODE", tableName) { sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('1'), ('½')") @@ -236,7 +245,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } test("literals in insert inherit session level collation") { - val tableName = "testcat.tbl_insert" + val tableName = "tbl_insert" withSessionCollationAndTable("UTF8_LCASE", tableName) { sql(s"CREATE TABLE $tableName (c1 BOOLEAN) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('a' = 'A')") From 67e8dbb28a620b956ee3eccf4ffa836528c11013 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 16 Oct 2024 19:43:23 +0200 Subject: [PATCH 07/54] fix typo --- .../main/scala/org/apache/spark/sql/types/StringType.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 319cc134090c7..a53e96e15a2e3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -101,7 +101,7 @@ class StringType private[sql] (private val _collationId: Int) extends AtomicType * @since 1.3.0 */ @Stable -case object StringType extends StringType(CollationFactory.DEFAULT_COLLATION_ID) { +case object StringType extends StringType(0) { private[spark] def apply(collationId: Int): StringType = new StringType(collationId) def apply(collation: String): StringType = { @@ -118,7 +118,7 @@ case object StringType extends StringType(CollationFactory.DEFAULT_COLLATION_ID) * we can still differentiate it from a regular string type, because in some places default string * is not the one with the session collation (e.g. in DDL commands). 
*/ -private[spark] class DefaultStringType extends StringType(0) { +private[spark] class DefaultStringType extends StringType(CollationFactory.DEFAULT_COLLATION_ID) { override def collationId: Int = SqlApiConf.get.defaultStringType.collationId } From 172d58fcbc3ff76dfb48ded71a64a8b2f45b3c00 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 17 Oct 2024 19:51:18 +0200 Subject: [PATCH 08/54] fix failing and add new tests --- .../catalyst/parser/DataTypeAstBuilder.scala | 2 +- .../apache/spark/sql/types/StringType.scala | 17 +++---- .../analysis/CollationTypeCasts.scala | 21 ++++++--- .../analysis/ReplaceDefaultStringType.scala | 44 ++++++++++++------- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../spark/sql/CollationSQLRegexpSuite.scala | 6 ++- .../org/apache/spark/sql/CollationSuite.scala | 23 ---------- .../collation/DefaultCollationTestSuite.scala | 43 +++++++++++++++--- 8 files changed, 94 insertions(+), 66 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 7fccdd46575f9..949f7eeae6217 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -74,7 +74,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => DefaultStringType + case Seq(_) => DefaultStringType() case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index a53e96e15a2e3..fab2dc2f991d0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -27,13 +27,11 @@ import org.apache.spark.sql.internal.SqlApiConf * The data type representing `String` values. Please use the singleton `DataTypes.StringType`. * * @since 1.3.0 - * @param _collationId + * @param collationId * The id of collation for this StringType. */ @Stable -class StringType private[sql] (private val _collationId: Int) extends AtomicType with Serializable { - - def collationId: Int = _collationId +class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if they are byte @@ -118,11 +116,10 @@ case object StringType extends StringType(0) { * we can still differentiate it from a regular string type, because in some places default string * is not the one with the session collation (e.g. in DDL commands). 
*/ -private[spark] class DefaultStringType extends StringType(CollationFactory.DEFAULT_COLLATION_ID) { - override def collationId: Int = SqlApiConf.get.defaultStringType.collationId -} +private[spark] class DefaultStringType private (collationId: Int) extends StringType(collationId) {} -private[spark] case object DefaultStringType extends DefaultStringType { - def apply(): DefaultStringType = new DefaultStringType() +private[spark] object DefaultStringType { + def apply(): DefaultStringType = { + new DefaultStringType(SqlApiConf.get.defaultStringType.collationId) + } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index ced54e590ecc2..06e12b0083a5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -24,7 +24,6 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} object CollationTypeCasts extends TypeCoercionRule { @@ -161,6 +160,8 @@ object CollationTypeCasts extends TypeCoercionRule { * complex DataTypes with collated StringTypes (e.g. ArrayType) */ def getOutputCollation(expr: Seq[Expression]): StringType = { + require(expr.nonEmpty, "Cannot determine collation for empty expressions") + val explicitTypes = expr.filter { case _: Collate => true case _ => false @@ -190,11 +191,19 @@ object CollationTypeCasts extends TypeCoercionRule { .map(extractStringType(_).collationId) .distinct - if (implicitTypes.length > 1) { - throw QueryCompilationErrors.implicitCollationMismatchError() - } - else { - implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType) + implicitTypes.length match { + case 1 => + StringType(implicitTypes.head) + + case size if size > 1 => + throw QueryCompilationErrors.implicitCollationMismatchError() + + case _ => + // If there are no implicit collations then all expressions must have + // the default collation so we can just return either one + require(expr.forall(_.dataType.sameType(expr.head.dataType))) + + expr.head.dataType.asInstanceOf[StringType] } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala index 0204efd8b2301..7999ed7969c6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, ColumnDefinition, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} @@ -32,8 +32,8 @@ import 
org.apache.spark.sql.util.SchemaUtils object ReplaceDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperatorsUp { - case _: V2CreateTablePlan => - transform(plan, StringType) + case create: V2CreateTablePlan => + transformPlan(create, StringType) case addCols: AddColumns => addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, StringType)) @@ -47,26 +47,34 @@ object ReplaceDefaultStringType extends Rule[LogicalPlan] { } } - private def transform(plan: LogicalPlan, newType: StringType): LogicalPlan = { - plan resolveOperatorsUp { - case operator => - operator transformExpressionsUp { - case columnDef: ColumnDefinition - if SchemaUtils.hasDefaultStringType(columnDef.dataType) => - columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) + private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = { + plan resolveOperatorsUp { operator => + operator.transformExpressionsUp { expression => + transformExpression(expression, newType) + } + } + } - case cast: Cast if SchemaUtils.hasDefaultStringType(cast.dataType) => - cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) + private def transformExpression(expression: Expression, newType: StringType): Expression = { + expression match { + case columnDef: ColumnDefinition + if SchemaUtils.hasDefaultStringType(columnDef.dataType) => + columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) - case Literal(value, dt) if SchemaUtils.hasDefaultStringType(dt) => - Literal(value, replaceDefaultStringType(dt, newType)) - } + case cast: Cast if SchemaUtils.hasDefaultStringType(cast.dataType) => + cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) + + case Literal(value, dt) if SchemaUtils.hasDefaultStringType(dt) => + val replaced = replaceDefaultStringType(dt, newType) + Literal(value, replaced) + + case other => other } } private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { dataType.transformRecursively { - case st: StringType if st.isInstanceOf[DefaultStringType] => newType + case _: DefaultStringType => newType } } @@ -75,7 +83,9 @@ object ReplaceDefaultStringType extends Rule[LogicalPlan] { newType: StringType): Seq[QualifiedColType] = { colTypes.map { case colWithDefault if SchemaUtils.hasDefaultStringType(colWithDefault.dataType) => - colWithDefault.copy(dataType = newType) + val replaced = replaceDefaultStringType(colWithDefault.dataType, newType) + colWithDefault.copy(dataType = replaced) + case col => col } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 70d2cca58cddc..cf6b6a26f2df2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2064,7 +2064,7 @@ class AstBuilder extends DataTypeAstBuilder */ private def canEagerlyEvaluateInlineTable(rows: Seq[Seq[Expression]]): Boolean = { val configSet = conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED) - val sessionCollationUnchanged = DefaultStringType == StringType + val sessionCollationUnchanged = DefaultStringType() == StringType configSet && (sessionCollationUnchanged || !rows.exists(_.exists(containsStringLiteral))) @@ -3291,7 +3291,7 @@ class AstBuilder extends DataTypeAstBuilder * Create a String 
literal expression. */ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal.create(createString(ctx), DefaultStringType) + Literal.create(createString(ctx), DefaultStringType()) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index 87dbbc65a3936..0de939aa62df3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -112,7 +112,11 @@ class CollationSQLRegexpSuite val tableNameLcase = "T_LCASE" withTable(tableNameLcase) { withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") { - sql(s"CREATE TABLE IF NOT EXISTS $tableNameLcase(c STRING) using PARQUET") + sql(s""" + |CREATE TABLE IF NOT EXISTS $tableNameLcase( + | c STRING COLLATE UTF8_LCASE + |) using PARQUET + |""".stripMargin) sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')") checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true)) checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b19af542dabf2..97ff85f6a682f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -977,29 +977,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("SPARK-47431: Default collation set to UNICODE, column type test") { - withTable("t") { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - sql(s"CREATE TABLE t(c1 STRING) USING PARQUET") - sql(s"INSERT INTO t VALUES ('a')") - checkAnswer(sql(s"SELECT collation(c1) FROM t"), Seq(Row("UNICODE"))) - } - } - } - - test("SPARK-47431: Create table with UTF8_BINARY, make sure collation persists on read") { - withTable("t") { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_BINARY") { - sql("CREATE TABLE t(c1 STRING) USING PARQUET") - sql("INSERT INTO t VALUES ('a')") - checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY"))) - } - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY"))) - } - } - } - test("Create dataframe with non utf8 binary collation") { val schema = StructType(Seq(StructField("Name", StringType("UNICODE_CI")))) val data = Seq(Row("Alice"), Row("Bob"), Row("bob")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index bc7f5b0dd72dd..a259e61ac0e1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -85,15 +85,22 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { // literals in select do not pick up session collation withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName USING $dataSource AS SELECT 'a' AS c1") + sql(s""" + |CREATE TABLE $tableName USING $dataSource AS SELECT + | 'a' AS c1, + | 'a' || 'a' AS c2, + | SUBSTRING('a', 1, 1) AS c3, + | SUBSTRING(SUBSTRING('ab', 1, 1), 1, 1) AS c4, + | 'a' = 'A' AS truthy + |""".stripMargin) 
assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") - } + assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") + assertTableColumnCollation(tableName, "c3", "UTF8_BINARY") + assertTableColumnCollation(tableName, "c4", "UTF8_BINARY") - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName USING $dataSource AS SELECT 'a' = 'A' AS c1") checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), - Seq(Row(1))) + sql(s"SELECT COUNT(*) FROM $tableName WHERE truthy"), + Seq(Row(0))) } // literals in inline table do not pick up session collation @@ -115,6 +122,30 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } } + test("ctas with complex types") { + val tableName = "tbl_complex" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s""" + |CREATE TABLE $tableName USING $dataSource AS + |SELECT + | struct('a') AS c1, + | map('a', 'b') AS c2, + | array('a') AS c3 + |""".stripMargin) + + checkAnswer( + sql(s"SELECT COLLATION(c1.col1) FROM $tableName"), + Seq(Row("UTF8_BINARY"))) + checkAnswer( + // TODO: other PR is supposed to fix explicit collation here + sql(s"SELECT COLLATION(c2['a' collate UTF8_BINARY]) FROM $tableName"), + Seq(Row("UTF8_BINARY"))) + checkAnswer( + sql(s"SELECT COLLATION(c3[0]) FROM $tableName"), + Seq(Row("UTF8_BINARY"))) + } + } + test("add column") { val tableName = "tbl_add_col" withSessionCollationAndTable("UTF8_LCASE", tableName) { From 7ac56543f21f92a725b8a1e654f49e7789eed4c2 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Fri, 18 Oct 2024 15:57:19 +0200 Subject: [PATCH 09/54] formatting --- .../apache/spark/sql/types/StringType.scala | 14 ++-- .../analysis/ReplaceDefaultStringType.scala | 10 +-- .../catalyst/plans/logical/v2Commands.scala | 3 +- .../collation/DefaultCollationTestSuite.scala | 74 ++++++------------- 4 files changed, 34 insertions(+), 67 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index de21b59000a6b..e4f3073981e84 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -112,14 +112,14 @@ case object StringType extends StringType(0) { } /** - * The result type of literals, column definitions without explicit collation, casts to string - * and some expressions that produce strings but whose output type is not based on the types of its - * children. - * Idea is to have this behave like a string with the default collation of the session, but that - * we can still differentiate it from a regular string type, because in some places default string - * is not the one with the session collation (e.g. in DDL commands). + * The result type of literals, column definitions without explicit collation, casts to string and + * some expressions that produce strings but whose output type is not based on the types of its + * children. Idea is to have this behave like a string with the default collation of the session, + * but that we can still differentiate it from a regular string type, because in some places + * default string is not the one with the session collation (e.g. in DDL commands). 
*/ -private[spark] class DefaultStringType private (collationId: Int) extends StringType(collationId) {} +private[spark] class DefaultStringType private (collationId: Int) + extends StringType(collationId) {} private[spark] object DefaultStringType { def apply(): DefaultStringType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala index 9a61c6bdb6c3c..cc8de23c416f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala @@ -15,7 +15,6 @@ * limitations under the License. */ - package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} @@ -25,9 +24,9 @@ import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} import org.apache.spark.sql.util.SchemaUtils /** - * Replaces default string types in DDL commands. - * DDL commands should have a default collation based on the object's collation, - * however, this is not implemented yet. So, we will just use UTF8_BINARY for now. + * Replaces default string types in DDL commands. DDL commands should have a default collation + * based on the object's collation, however, this is not implemented yet. So, we will just use + * UTF8_BINARY for now. */ object ReplaceDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { @@ -60,8 +59,7 @@ object ReplaceDefaultStringType extends Rule[LogicalPlan] { private def transformExpression(expression: Expression, newType: StringType): Expression = { expression match { - case columnDef: ColumnDefinition - if SchemaUtils.hasDefaultStringType(columnDef.dataType) => + case columnDef: ColumnDefinition if SchemaUtils.hasDefaultStringType(columnDef.dataType) => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) case cast: Cast if SchemaUtils.hasDefaultStringType(cast.dataType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 8e9a7487c988f..f6eba1552cfff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -89,7 +89,8 @@ trait V2PartitionCommand extends UnaryCommand { } /** - * Trait to + * Trait to add to v1 ddl commands so that we can still catch those plan nodes + * in the catalyst analyzer rules. */ trait V1DDLCommand {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index a259e61ac0e1f..edc8ae9482f78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -15,7 +15,6 @@ * limitations under the License. 
*/ - package org.apache.spark.sql.collation import org.apache.spark.sql.Row @@ -41,7 +40,10 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } } - def assertTableColumnCollation(table: String, column: String, expectedCollation: String): Unit = { + def assertTableColumnCollation( + table: String, + column: String, + expectedCollation: String): Unit = { val colType = spark.table(table).schema(column).dataType assert(colType === StringType(expectedCollation)) } @@ -98,9 +100,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { assertTableColumnCollation(tableName, "c3", "UTF8_BINARY") assertTableColumnCollation(tableName, "c4", "UTF8_BINARY") - checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName WHERE truthy"), - Seq(Row(0))) + checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE truthy"), Seq(Row(0))) } // literals in inline table do not pick up session collation @@ -110,9 +110,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { |SELECT c1, c1 = 'A' as c2 FROM VALUES ('a'), ('A') AS vals(c1) |""".stripMargin) assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") - checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName WHERE c2"), - Seq(Row(1))) + checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE c2"), Seq(Row(1))) } // cast in select does not pick up session collation @@ -133,16 +131,12 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { | array('a') AS c3 |""".stripMargin) - checkAnswer( - sql(s"SELECT COLLATION(c1.col1) FROM $tableName"), - Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c1.col1) FROM $tableName"), Seq(Row("UTF8_BINARY"))) checkAnswer( // TODO: other PR is supposed to fix explicit collation here sql(s"SELECT COLLATION(c2['a' collate UTF8_BINARY]) FROM $tableName"), Seq(Row("UTF8_BINARY"))) - checkAnswer( - sql(s"SELECT COLLATION(c3[0]) FROM $tableName"), - Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c3[0]) FROM $tableName"), Seq(Row("UTF8_BINARY"))) } } @@ -168,37 +162,25 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { withSessionCollation("UTF8_LCASE") { // literal without collation - checkAnswer( - sql("SELECT COLLATION('a')"), - Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION('a')"), Seq(Row("UTF8_LCASE"))) - checkAnswer( - sql("SELECT COLLATION(map('a', 'b')['a'])"), - Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION(map('a', 'b')['a'])"), Seq(Row("UTF8_LCASE"))) - checkAnswer( - sql("SELECT COLLATION(array('a')[0])"), - Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION(array('a')[0])"), Seq(Row("UTF8_LCASE"))) - checkAnswer( - sql("SELECT COLLATION(struct('a' as c)['c'])"), - Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION(struct('a' as c)['c'])"), Seq(Row("UTF8_LCASE"))) } } test("literals with explicit collation") { withSessionCollation("UTF8_LCASE") { - checkAnswer( - sql("SELECT COLLATION('a' collate unicode)"), - Seq(Row("UNICODE"))) + checkAnswer(sql("SELECT COLLATION('a' collate unicode)"), Seq(Row("UNICODE"))) checkAnswer( sql("SELECT COLLATION(map('a', 'b' collate unicode)['a'])"), Seq(Row("UNICODE"))) - checkAnswer( - sql("SELECT COLLATION(array('a' collate unicode)[0])"), - Seq(Row("UNICODE"))) + checkAnswer(sql("SELECT COLLATION(array('a' collate unicode)[0])"), Seq(Row("UNICODE"))) checkAnswer( sql("SELECT COLLATION(struct('a' collate unicode as c)['c'])"), @@ -229,13 +211,9 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { test("expressions in where are aware of session 
collation") { withSessionCollation("UTF8_LCASE") { // expression in where is aware of session collation - checkAnswer( - sql("SELECT 1 WHERE 'a' = 'A'"), - Seq(Row(1))) + checkAnswer(sql("SELECT 1 WHERE 'a' = 'A'"), Seq(Row(1))) - checkAnswer( - sql("SELECT 1 WHERE 'a' = cast('A' as STRING)"), - Seq(Row(1))) + checkAnswer(sql("SELECT 1 WHERE 'a' = cast('A' as STRING)"), Seq(Row(1))) } } @@ -264,13 +242,9 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('1'), ('½')") - checkAnswer( - sql(s"SELECT MIN(c1) FROM $tableName"), - Seq(Row("1"))) + checkAnswer(sql(s"SELECT MIN(c1) FROM $tableName"), Seq(Row("1"))) - checkAnswer( - sql(s"SELECT MAX(c1) FROM $tableName"), - Seq(Row("½"))) + checkAnswer(sql(s"SELECT MAX(c1) FROM $tableName"), Seq(Row("½"))) } // scalastyle:on nonascii } @@ -281,14 +255,10 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { sql(s"CREATE TABLE $tableName (c1 BOOLEAN) USING $dataSource") sql(s"INSERT INTO $tableName VALUES ('a' = 'A')") - checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), - Seq(Row(1))) + checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), Seq(Row(1))) sql(s"INSERT INTO $tableName VALUES (array_contains(array('a'), 'A'))") - checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), - Seq(Row(2))) + checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), Seq(Row(2))) } } @@ -299,9 +269,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") sql(s"DELETE FROM $tableName WHERE 'a' = 'A'") - checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName"), - Seq(Row(0))) + checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName"), Seq(Row(0))) } } // endregion From 82dcbf46ddd353c334a5273d8e2d7b2299ff6ab4 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Fri, 18 Oct 2024 18:13:37 +0200 Subject: [PATCH 10/54] fix toString method --- .../main/scala/org/apache/spark/sql/types/StringType.scala | 4 ++++ .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index e4f3073981e84..27836e128cee3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -78,6 +78,10 @@ class StringType private[sql] (val collationId: Int) extends AtomicType with Ser if (isUTF8BinaryCollation) "string" else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" + override def toString: String = + if (isUTF8BinaryCollation) "StringType" + else s"StringType($collationId)" + // Due to backwards compatibility and compatibility with other readers // all string types are serialized in json as regular strings and // the collation information is written to struct field metadata diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index ea2c68237c570..deafacd890de9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -159,8 +159,8 @@ object CollationTypeCasts 
extends TypeCoercionRule { * any expressions, but will only be affected by collated StringTypes or * complex DataTypes with collated StringTypes (e.g. ArrayType) */ - def getOutputCollation(expressions: Seq[Expression]): StringType = { - val explicitTypes = expressions.filter { + def getOutputCollation(expr: Seq[Expression]): StringType = { + val explicitTypes = expr.filter { case _: Collate => true case _ => false } @@ -178,7 +178,7 @@ object CollationTypeCasts extends TypeCoercionRule { ) // Only implicit or default collations present case 0 => - val implicitTypes = expressions.filter { + val implicitTypes = expr.filter { case Literal(_, _: StringType) => false case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty => cast.child.dataType.isInstanceOf[StringType] From 014c855795b28864162fe4ef7e67691e2432373f Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 20 Oct 2024 01:39:04 +0200 Subject: [PATCH 11/54] fix duplicate test name --- .../test/scala/org/apache/spark/sql/types/DataTypeSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 3241f031a706b..d87fe919cf249 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -715,7 +715,7 @@ class DataTypeSuite extends SparkFunSuite { checkEqualsIgnoreCompatibleCollation(StringType, StringType("UTF8_LCASE"), expected = true) checkEqualsIgnoreCompatibleCollation( - StringType("UTF8_BINARY"), StringType("UTF8_LCASE"), expected = true) + StringType("UTF8_LCASE"), StringType("UTF8_BINARY"), expected = true) // Complex types. 
checkEqualsIgnoreCompatibleCollation( ArrayType(StringType), From fd865902da81ace7277f6c6aa7314fad9781ba01 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Mon, 21 Oct 2024 10:26:19 +0200 Subject: [PATCH 12/54] trigger ci From 99c9dd25b2b2cc085dc27af51562bf13aaaa1a5c Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 22 Oct 2024 09:38:45 +0200 Subject: [PATCH 13/54] trigger ci From 32d1d1ba182ed1386bd9956ba2dffa950683f7b6 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 22 Oct 2024 09:57:41 +0200 Subject: [PATCH 14/54] add more tests --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- ...e.scala => ResolveDefaultStringType.scala} | 7 ++- .../collation/DefaultCollationTestSuite.scala | 49 +++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/{ReplaceDefaultStringType.scala => ResolveDefaultStringType.scala} (94%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6e75f800e2f5a..5fd2ddf43e3c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -314,7 +314,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveProcedures :: BindProcedures :: ResolveTableSpec :: - ReplaceDefaultStringType :: + ResolveDefaultStringType :: ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala similarity index 94% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index cc8de23c416f5..ddbc40cb09231 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} import org.apache.spark.sql.util.SchemaUtils /** - * Replaces default string types in DDL commands. DDL commands should have a default collation + * Resolves default string types in DDL commands. DDL commands should have a default collation * based on the object's collation, however, this is not implemented yet. So, we will just use * UTF8_BINARY for now. 
*/ -object ReplaceDefaultStringType extends Rule[LogicalPlan] { +object ResolveDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { case createTable: V2CreateTablePlan => @@ -66,8 +66,7 @@ object ReplaceDefaultStringType extends Rule[LogicalPlan] { cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) case Literal(value, dt) if SchemaUtils.hasDefaultStringType(dt) => - val replaced = replaceDefaultStringType(dt, newType) - Literal(value, replaced) + Literal(value, replaceDefaultStringType(dt, newType)) case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index edc8ae9482f78..b0cd4b0c31f5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -140,6 +140,35 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } } + test("ctas with union") { + val tableName = "tbl_union" + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s""" + |CREATE TABLE $tableName USING $dataSource AS + |SELECT 'a' = 'A' AS c1 + |UNION + |SELECT 'b' = 'B' AS c1 + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableName"), + Seq(Row(false))) + } + + withSessionCollationAndTable("UTF8_LCASE", tableName) { + sql(s""" + |CREATE TABLE $tableName USING $dataSource AS + |SELECT 'a' = 'A' AS c1 + |UNION ALL + |SELECT 'b' = 'B' AS c1 + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableName"), + Seq(Row(false), Row(false))) + } + } + test("add column") { val tableName = "tbl_add_col" withSessionCollationAndTable("UTF8_LCASE", tableName) { @@ -249,6 +278,26 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { // scalastyle:on nonascii } + test("union operation with subqueries") { + withSessionCollation("UTF8_LCASE") { + checkAnswer( + sql(s""" + |SELECT 'a' = 'A' + |UNION + |SELECT 'b' = 'B' + |""".stripMargin), + Seq(Row(true))) + + checkAnswer( + sql(s""" + |SELECT 'a' = 'A' + |UNION ALL + |SELECT 'b' = 'B' + |""".stripMargin), + Seq(Row(true), Row(true))) + } + } + test("literals in insert inherit session level collation") { val tableName = "tbl_insert" withSessionCollationAndTable("UTF8_LCASE", tableName) { From 29f9a18252f68069b5827c37e3665bd1e30c65f9 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 23 Oct 2024 18:44:40 +0200 Subject: [PATCH 15/54] add support for create/alter view --- .../analysis/ResolveDefaultStringType.scala | 14 ++- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../collation/DefaultCollationTestSuite.scala | 108 +++++++++++++++++- 3 files changed, 114 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index ddbc40cb09231..f92b37c5dbe55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, 
ColumnDefinition, LogicalPlan, QualifiedColType, ReplaceColumns, V1DDLCommand, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1DDLCommand, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} import org.apache.spark.sql.util.SchemaUtils @@ -31,11 +31,8 @@ import org.apache.spark.sql.util.SchemaUtils object ResolveDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { - case createTable: V2CreateTablePlan => - transformPlan(createTable, StringType) - - case v1Ddl: V1DDLCommand => - transformPlan(v1Ddl, StringType) + case p if isCreateOrAlterPlan(p) => + transformPlan(p, StringType) case addCols: AddColumns => addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, StringType)) @@ -49,6 +46,11 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { } } + private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1DDLCommand => true + case _ => false + } + private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = { plan resolveOperators { operator => operator.transformExpressionsUp { expression => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 3be77b4cb94ef..69186ea4010dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -79,7 +79,7 @@ case class CreateTempViewUsing( replace: Boolean, global: Boolean, provider: String, - options: Map[String, String]) extends LeafRunnableCommand with V1DDLCommand { + options: Map[String, String]) extends LeafRunnableCommand { if (tableIdent.database.isDefined) { throw QueryCompilationErrors.cannotSpecifyDatabaseForTempViewError(tableIdent) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index b0cd4b0c31f5a..ff58dc22e2f1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.collation -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.connector.DatasourceV2SQLBase import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.StringType @@ -26,8 +26,16 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { val dataSource: String = "parquet" - def withSessionCollationAndTable(collation: String, tableName: String)(f: => Unit): Unit = { - withTable(tableName) { + def withSessionCollationAndTable(collation: String, tableNames: String*)(f: => Unit): Unit = { + withTable(tableNames: _*) { + withSessionCollation(collation) { + f + } + } + } + + def withSessionCollationAndView(collation: String, viewNames: String*)(f: => Unit): Unit = { + withView(viewNames: _*) { withSessionCollation(collation) { f } @@ -48,6 +56,13 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { 
assert(colType === StringType(expectedCollation)) } + def assertThrowsImplicitMismatch(f: => DataFrame): Unit = { + val exception = intercept[AnalysisException] { + f + } + assert(exception.getCondition === "COLLATION_MISMATCH.IMPLICIT") + } + // region DDL tests test("create/alter table") { @@ -183,6 +198,93 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } } + test("create/alter view created from a table") { + val tableName = "tbl_view" + val viewName = "view2" + val sessionCollation = "UTF8_LCASE" + withSessionCollationAndTable(sessionCollation, tableName) { + sql(s"CREATE TABLE $tableName (c1 STRING, c2 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"INSERT INTO $tableName VALUES ('a', 'a'), ('A', 'A')") + + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") + + assertTableColumnCollation(viewName, "c1", "UTF8_BINARY") + assertTableColumnCollation(viewName, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $viewName"), + Row("UTF8_BINARY", sessionCollation) + ) + + // filter should use session collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $viewName WHERE 'a' = 'A'"), + Row(2) + ) + + // filter should use column collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), + Row(1) + ) + + // literal with explicit collation wins + checkAnswer( + sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A' collate UNICODE_CI"), + Row(2) + ) + + // two implicit collations -> errors out + assertThrowsImplicitMismatch(sql(s"SELECT c1 = substring('A', 0, 1) FROM $viewName")) + + sql(s"ALTER VIEW $viewName AS SELECT c1 COLLATE UNICODE_CI AS c1, c2 FROM $tableName") + assertTableColumnCollation(viewName, "c1", "UNICODE_CI") + assertTableColumnCollation(viewName, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $viewName"), + Row("UNICODE_CI", sessionCollation) + ) + + // after alter both rows should be returned + checkAnswer( + sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), + Row(2) + ) + } + } + } + + test("join view with table") { + val viewTableName = "view_table" + val joinTableName = "join_table" + val viewName = "view" + val sessionCollation = "sr" + + withSessionCollationAndTable(sessionCollation, viewTableName, joinTableName) { + sql(s"CREATE TABLE $viewTableName (c1 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"CREATE TABLE $joinTableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + sql(s"INSERT INTO $viewTableName VALUES ('a')") + sql(s"INSERT INTO $joinTableName VALUES ('A')") + + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $viewTableName") + + assertThrowsImplicitMismatch( + sql(s"SELECT * FROM $viewName JOIN $joinTableName ON $viewName.c1 = $joinTableName.c1") + ) + + checkAnswer( + sql(s""" + |SELECT COLLATION($viewName.c1), COLLATION($joinTableName.c1) + |FROM $viewName JOIN $joinTableName + |ON $viewName.c1 = $joinTableName.c1 COLLATE UNICODE_CI + |""".stripMargin), + Row("UNICODE_CI", "UTF8_LCASE") + ) + } + } + } + // endregion // region DML tests From 8df3263d3646a4e338b4600f4f17f469c8631760 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 23 Oct 2024 18:46:49 +0200 Subject: [PATCH 16/54] remove explicit collation in map access --- .../apache/spark/sql/collation/DefaultCollationTestSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index ff58dc22e2f1c..5c411d311214c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -148,8 +148,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { checkAnswer(sql(s"SELECT COLLATION(c1.col1) FROM $tableName"), Seq(Row("UTF8_BINARY"))) checkAnswer( - // TODO: other PR is supposed to fix explicit collation here - sql(s"SELECT COLLATION(c2['a' collate UTF8_BINARY]) FROM $tableName"), + sql(s"SELECT COLLATION(c2['a']) FROM $tableName"), Seq(Row("UTF8_BINARY"))) checkAnswer(sql(s"SELECT COLLATION(c3[0]) FROM $tableName"), Seq(Row("UTF8_BINARY"))) } From fd541a008d34ab627c6580ffde05de56671a2a64 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 29 Oct 2024 20:25:23 +0100 Subject: [PATCH 17/54] do not break parser for StringType --- .../catalyst/parser/DataTypeAstBuilder.scala | 6 ++-- .../apache/spark/sql/types/StringType.scala | 18 ++++------ .../analysis/CollationTypeCasts.scala | 5 +-- .../analysis/ResolveDefaultStringType.scala | 36 +++++++++++++------ .../apache/spark/sql/util/SchemaUtils.scala | 9 +---- .../collation/DefaultCollationTestSuite.scala | 34 ++++++++++-------- 6 files changed, 59 insertions(+), 49 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index b0e7e5846656c..7986934f6ddc4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.connector.catalog.IdentityColumnSpec import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DefaultStringType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StronglyTypedStringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -76,11 +76,11 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => DefaultStringType() + case Seq(_) => SqlApiConf.get.defaultStringType case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) - StringType(collationId) + StronglyTypedStringType(collationId) } case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt) 
case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 27836e128cee3..a1e1f66ca0e72 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -21,7 +21,6 @@ import org.json4s.JsonAST.{JString, JValue} import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.internal.SqlApiConf /** * The data type representing `String` values. Please use the singleton `DataTypes.StringType`. @@ -116,17 +115,14 @@ case object StringType extends StringType(0) { } /** - * The result type of literals, column definitions without explicit collation, casts to string and - * some expressions that produce strings but whose output type is not based on the types of its - * children. Idea is to have this behave like a string with the default collation of the session, - * but that we can still differentiate it from a regular string type, because in some places - * default string is not the one with the session collation (e.g. in DDL commands). + * String type which results from a strongly typed string declaration + * `STRING COLLATE ...`, and not just an implicit `STRING`. */ -private[spark] class DefaultStringType private (collationId: Int) - extends StringType(collationId) {} +private[spark] class StronglyTypedStringType private(collationId: Int) + extends StringType(collationId) {} -private[spark] object DefaultStringType { - def apply(): DefaultStringType = { - new DefaultStringType(SqlApiConf.get.defaultStringType.collationId) +private[spark] object StronglyTypedStringType { + def apply(id: Int): StronglyTypedStringType = { + new StronglyTypedStringType(id) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 3c5cf6fb2520f..bd80df0e8bf52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -24,7 +24,8 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{ArrayType, DataType, DefaultStringType, MapType, StringType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { @@ -203,7 +204,7 @@ object CollationTypeCasts extends TypeCoercionRule { ) } else { - implicitTypes.headOption.map(StringType(_)).getOrElse(DefaultStringType()) + implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index f92b37c5dbe55..554037ef2d362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1DDLCommand, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} -import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StringType, StronglyTypedStringType} /** - * Resolves default string types in DDL commands. DDL commands should have a default collation - * based on the object's collation, however, this is not implemented yet. So, we will just use - * UTF8_BINARY for now. + * Resolves default string types in DDL commands. For DML commands, the default string type is + * determined by the session's default string type. For DDL, the default string type is + * the default type of the object (table -> schema -> catalog). + * However, this is not implemented yet. So, we will just use UTF8_BINARY for now. */ object ResolveDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { @@ -41,7 +42,7 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, StringType)) case alter: AlterColumn - if alter.dataType.isDefined && SchemaUtils.hasDefaultStringType(alter.dataType.get) => + if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) => alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, StringType))) } } @@ -61,22 +62,35 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { private def transformExpression(expression: Expression, newType: StringType): Expression = { expression match { - case columnDef: ColumnDefinition if SchemaUtils.hasDefaultStringType(columnDef.dataType) => + case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) - case cast: Cast if SchemaUtils.hasDefaultStringType(cast.dataType) => + case cast: Cast if hasDefaultStringType(cast.dataType) => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) - case Literal(value, dt) if SchemaUtils.hasDefaultStringType(dt) => + case Literal(value, dt) if hasDefaultStringType(dt) => Literal(value, replaceDefaultStringType(dt, newType)) case other => other } } + private def hasDefaultStringType(dataType: DataType): Boolean = + dataType.existsRecursively(isDefaultStringType) + + private def isDefaultStringType(dataType: DataType): Boolean = { + dataType match { + case st: StringType => + val sessionCollation = SQLConf.get.defaultStringType + st == sessionCollation && !st.isInstanceOf[StronglyTypedStringType] + case _ => false + } + } + private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { dataType.transformRecursively { - case _: DefaultStringType => newType + case _ if isDefaultStringType(dataType) => + newType } } @@ -84,7 +98,7 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { colTypes: Seq[QualifiedColType], newType: StringType): Seq[QualifiedColType] = { colTypes.map { - case colWithDefault if 
SchemaUtils.hasDefaultStringType(colWithDefault.dataType) => + case colWithDefault if hasDefaultStringType(colWithDefault.dataType) => val replaced = replaceDefaultStringType(colWithDefault.dataType, newType) colWithDefault.copy(dataType = replaced) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 277ae0a92bdf5..1e0bac331dc75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{ArrayType, DataType, DefaultStringType, MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkSchemaUtils @@ -304,13 +304,6 @@ private[spark] object SchemaUtils { } } - /** - * Checks if a given data type has a default string type. - */ - def hasDefaultStringType(dataType: DataType): Boolean = { - dataType.existsRecursively(_.isInstanceOf[DefaultStringType]) - } - /** * Replaces any collated string type with non collated StringType * recursively in the given data type. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 5c411d311214c..ad99cbe5083ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -289,16 +289,17 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { // region DML tests test("literals with default collation") { - withSessionCollation("UTF8_LCASE") { + val sessionCollation = "UTF8_LCASE" + withSessionCollation(sessionCollation) { // literal without collation - checkAnswer(sql("SELECT COLLATION('a')"), Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION('a')"), Seq(Row(sessionCollation))) - checkAnswer(sql("SELECT COLLATION(map('a', 'b')['a'])"), Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION(map('a', 'b')['a'])"), Seq(Row(sessionCollation))) - checkAnswer(sql("SELECT COLLATION(array('a')[0])"), Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION(array('a')[0])"), Seq(Row(sessionCollation))) - checkAnswer(sql("SELECT COLLATION(struct('a' as c)['c'])"), Seq(Row("UTF8_LCASE"))) + checkAnswer(sql("SELECT COLLATION(struct('a' as c)['c'])"), Seq(Row(sessionCollation))) } } @@ -319,22 +320,27 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } test("cast is aware of session collation") { - withSessionCollation("UTF8_LCASE") { + val sessionCollation = "UTF8_LCASE" + withSessionCollation(sessionCollation) { + checkAnswer( + sql("SELECT COLLATION(cast('a' as STRING))"), + Seq(Row(sessionCollation))) + checkAnswer( - sql("SELECT COLLATION(cast('a' collate unicode as STRING))"), - Seq(Row("UTF8_LCASE"))) + sql("SELECT COLLATION(cast(map('a', 'b') as MAP)['a'])"), + Seq(Row(sessionCollation))) checkAnswer( - 
sql("SELECT COLLATION(cast(map('a', 'b' collate unicode) as MAP)['a'])"), - Seq(Row("UTF8_LCASE"))) + sql("SELECT COLLATION(map_keys(cast(map('a', 'b') as MAP))[0])"), + Seq(Row(sessionCollation))) checkAnswer( - sql("SELECT COLLATION(cast(array('a' collate unicode) as ARRAY)[0])"), - Seq(Row("UTF8_LCASE"))) + sql("SELECT COLLATION(cast(array('a') as ARRAY)[0])"), + Seq(Row(sessionCollation))) checkAnswer( - sql("SELECT COLLATION(cast(struct('a' collate unicode as c) as STRUCT)['c'])"), - Seq(Row("UTF8_LCASE"))) + sql("SELECT COLLATION(cast(struct('a' as c) as STRUCT)['c'])"), + Seq(Row(sessionCollation))) } } From e332022b77484ca2a35ae3f90a205c5b7feb5554 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 29 Oct 2024 20:55:40 +0100 Subject: [PATCH 18/54] fix compilation err --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5603474488275..da0db687e7d2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2088,7 +2088,7 @@ class AstBuilder extends DataTypeAstBuilder */ private def canEagerlyEvaluateInlineTable(rows: Seq[Seq[Expression]]): Boolean = { val configSet = conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED) - val sessionCollationUnchanged = DefaultStringType() == StringType + val sessionCollationUnchanged = conf.defaultStringType == StringType configSet && (sessionCollationUnchanged || !rows.exists(_.exists(containsStringLiteral))) @@ -3315,7 +3315,7 @@ class AstBuilder extends DataTypeAstBuilder * Create a String literal expression. */ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal.create(createString(ctx), DefaultStringType()) + Literal.create(createString(ctx), conf.defaultStringType) } /** From 7fafc99e533853cc6d5de9c0d2b0b5e182fc9588 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 29 Oct 2024 21:40:21 +0100 Subject: [PATCH 19/54] fmt --- .../apache/spark/sql/types/StringType.scala | 8 ++-- .../analysis/ResolveDefaultStringType.scala | 6 +-- .../collation/DefaultCollationTestSuite.scala | 46 +++++-------------- 3 files changed, 19 insertions(+), 41 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index a1e1f66ca0e72..a563276742a56 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -115,11 +115,11 @@ case object StringType extends StringType(0) { } /** - * String type which result from a strongly typed string declaration - * `STRING COLLATE ...`, and not just implicit `STRING`. + * String type which result from a strongly typed string declaration `STRING COLLATE ...`, and not + * just implicit `STRING`. 
*/ -private[spark] class StronglyTypedStringType private(collationId: Int) - extends StringType(collationId) {} +private[spark] class StronglyTypedStringType private (collationId: Int) + extends StringType(collationId) {} private[spark] object StronglyTypedStringType { def apply(id: Int): StronglyTypedStringType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index 554037ef2d362..0a9667ff57d35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.types.{DataType, StringType, StronglyTypedStringType /** * Resolves default string types in DDL commands. For DML commands, the default string type is - * determined by the session's default string type. For DDL, the default string type is - * the default type of the object (table -> schema -> catalog). - * However, this is not implemented yet. So, we will just use UTF8_BINARY for now. + * determined by the session's default string type. For DDL, the default string type is the + * default type of the object (table -> schema -> catalog). However, this is not implemented yet. + * So, we will just use UTF8_BINARY for now. */ object ResolveDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index ad99cbe5083ff..5fe16570c11e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -147,9 +147,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { |""".stripMargin) checkAnswer(sql(s"SELECT COLLATION(c1.col1) FROM $tableName"), Seq(Row("UTF8_BINARY"))) - checkAnswer( - sql(s"SELECT COLLATION(c2['a']) FROM $tableName"), - Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c2['a']) FROM $tableName"), Seq(Row("UTF8_BINARY"))) checkAnswer(sql(s"SELECT COLLATION(c3[0]) FROM $tableName"), Seq(Row("UTF8_BINARY"))) } } @@ -164,9 +162,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { |SELECT 'b' = 'B' AS c1 |""".stripMargin) - checkAnswer( - sql(s"SELECT * FROM $tableName"), - Seq(Row(false))) + checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(false))) } withSessionCollationAndTable("UTF8_LCASE", tableName) { @@ -177,9 +173,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { |SELECT 'b' = 'B' AS c1 |""".stripMargin) - checkAnswer( - sql(s"SELECT * FROM $tableName"), - Seq(Row(false), Row(false))) + checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(false), Row(false))) } } @@ -212,26 +206,18 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { assertTableColumnCollation(viewName, "c2", "UNICODE_CI") checkAnswer( sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $viewName"), - Row("UTF8_BINARY", sessionCollation) - ) + Row("UTF8_BINARY", sessionCollation)) // filter should use session collation - checkAnswer( - sql(s"SELECT COUNT(*) FROM $viewName WHERE 'a' = 'A'"), - Row(2) - ) + checkAnswer(sql(s"SELECT COUNT(*) FROM $viewName WHERE 'a' = 'A'"), Row(2)) 
// filter should use column collation - checkAnswer( - sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), - Row(1) - ) + checkAnswer(sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), Row(1)) // literal with explicit collation wins checkAnswer( sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A' collate UNICODE_CI"), - Row(2) - ) + Row(2)) // two implicit collations -> errors out assertThrowsImplicitMismatch(sql(s"SELECT c1 = substring('A', 0, 1) FROM $viewName")) @@ -241,14 +227,10 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { assertTableColumnCollation(viewName, "c2", "UNICODE_CI") checkAnswer( sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $viewName"), - Row("UNICODE_CI", sessionCollation) - ) + Row("UNICODE_CI", sessionCollation)) // after alter both rows should be returned - checkAnswer( - sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), - Row(2) - ) + checkAnswer(sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), Row(2)) } } } @@ -269,8 +251,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { sql(s"CREATE VIEW $viewName AS SELECT * FROM $viewTableName") assertThrowsImplicitMismatch( - sql(s"SELECT * FROM $viewName JOIN $joinTableName ON $viewName.c1 = $joinTableName.c1") - ) + sql(s"SELECT * FROM $viewName JOIN $joinTableName ON $viewName.c1 = $joinTableName.c1")) checkAnswer( sql(s""" @@ -278,8 +259,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { |FROM $viewName JOIN $joinTableName |ON $viewName.c1 = $joinTableName.c1 COLLATE UNICODE_CI |""".stripMargin), - Row("UNICODE_CI", "UTF8_LCASE") - ) + Row("UNICODE_CI", "UTF8_LCASE")) } } } @@ -322,9 +302,7 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { test("cast is aware of session collation") { val sessionCollation = "UTF8_LCASE" withSessionCollation(sessionCollation) { - checkAnswer( - sql("SELECT COLLATION(cast('a' as STRING))"), - Seq(Row(sessionCollation))) + checkAnswer(sql("SELECT COLLATION(cast('a' as STRING))"), Seq(Row(sessionCollation))) checkAnswer( sql("SELECT COLLATION(cast(map('a', 'b') as MAP)['a'])"), From 175f70337bc923470aca1ae49b92c22848e98c64 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 13 Nov 2024 16:31:39 +0100 Subject: [PATCH 20/54] remove StronglyTypedStringType --- .../catalyst/parser/DataTypeAstBuilder.scala | 4 +-- .../apache/spark/sql/types/StringType.scala | 21 ++++++++------ .../analysis/ResolveDefaultStringType.scala | 28 ++++++++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 4 +-- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 7986934f6ddc4..71e8517a4164e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.connector.catalog.IdentityColumnSpec import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StronglyTypedStringType, StructField, 
StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -80,7 +80,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) - StronglyTypedStringType(collationId) + StringType(collationId) } case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt) case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index a563276742a56..fd53071a965c4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -115,14 +115,19 @@ case object StringType extends StringType(0) { } /** - * String type which results from a strongly typed string declaration `STRING COLLATE ...`, and not - * just an implicit `STRING`. + * Data type representing a string with the default collation (what the default + * collation is depends on whether it is a DML or DDL statement). + * + * In DML statements the default collation is determined by the session collation, + * while in DDL statements the default collation is determined by the collation of the object + * (table, view, etc.).
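+ * For example, with the session collation set to UTF8_LCASE, the literal in + * `SELECT COLLATION('a')` resolves to UTF8_LCASE, while an illustrative + * `CREATE TABLE t (c1 STRING)` still resolves `c1` to the object default, + * currently always UTF8_BINARY.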
*/ -private[spark] class StronglyTypedStringType private (collationId: Int) - extends StringType(collationId) {} +private[spark] case class DefaultStringType(override val collationId: Int) + extends StringType(collationId) -private[spark] object StronglyTypedStringType { - def apply(id: Int): StronglyTypedStringType = { - new StronglyTypedStringType(id) +private[spark] object DefaultStringType { + def apply(collationName: String): DefaultStringType = { + val collationId = CollationFactory.collationNameToId(collationName) + new DefaultStringType(collationId) } -} +} \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index 0a9667ff57d35..68b3bb896a7c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1DDLCommand, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StringType, StronglyTypedStringType} +import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} /** * Resolves default string types in DDL commands. For DML commands, the default string type is @@ -31,22 +31,35 @@ import org.apache.spark.sql.types.{DataType, StringType, StronglyTypedStringType */ object ResolveDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { +// if (isDefaultSessionCollationUsed) { +// return plan +// } + plan resolveOperators { case p if isCreateOrAlterPlan(p) => - transformPlan(p, StringType) + val newType = stringTypeForDDLCommand(p) + transformPlan(p, newType) case addCols: AddColumns => - addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, StringType)) + val newType = stringTypeForDDLCommand(addCols.table) + addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, newType)) case replaceCols: ReplaceColumns => - replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, StringType)) + val newType = stringTypeForDDLCommand(replaceCols.table) + replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType)) case alter: AlterColumn if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) => - alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, StringType))) + val newType = stringTypeForDDLCommand(alter.table) + alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType))) } } + private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType + + private def stringTypeForDDLCommand(table: LogicalPlan): StringType = + StringType + private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1DDLCommand => true case _ => false @@ -80,9 +93,10 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { private def isDefaultStringType(dataType: DataType): Boolean = { dataType match { + case _: DefaultStringType => 
true case st: StringType => - val sessionCollation = SQLConf.get.defaultStringType - st == sessionCollation && !st.isInstanceOf[StronglyTypedStringType] + // should only return true for StringType object and not StringType("UTF8_BINARY") + st.eq(StringType) case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1c9f5e85d1a06..8b4bd705ec296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical.HintErrorHandler import org.apache.spark.sql.catalyst.util.{CollationFactory, DateTimeUtils} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{AtomicType, StringType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{AtomicType, DefaultStringType, StringType, TimestampNTZType, TimestampType} import org.apache.spark.storage.{StorageLevel, StorageLevelMapper} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.{Utils, VersionUtils} @@ -5508,7 +5508,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf { if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") { StringType } else { - StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION))) + DefaultStringType(getConf(DEFAULT_COLLATION)) } } From d717fed194f88c5e9b608aa3af9b93c51b19c7f0 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 13 Nov 2024 16:33:44 +0100 Subject: [PATCH 21/54] fix small bug in transform --- .../spark/sql/catalyst/analysis/ResolveDefaultStringType.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index 68b3bb896a7c8..920454fcb4ac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -103,7 +103,7 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { dataType.transformRecursively { - case _ if isDefaultStringType(dataType) => + case currentType if isDefaultStringType(currentType) => newType } } From 143d2c0ac8dc7902d498579ae07ff3fc1f605c74 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 13 Nov 2024 16:35:10 +0100 Subject: [PATCH 22/54] fix scalastyle --- .../src/main/scala/org/apache/spark/sql/types/StringType.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index fd53071a965c4..87fd76f0c46f4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -130,4 +130,4 @@ private[spark] object DefaultStringType { val collationId = CollationFactory.collationNameToId(collationName) new DefaultStringType(collationId) } -} \ No newline at end of file +} 
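A note on the one-line change in [PATCH 21/54] above: `transformRecursively` applies its partial function to every nested type, so the guard must test the node currently being visited rather than the outer `dataType` argument captured from the enclosing method. A minimal sketch of the two variants, reusing the names from ResolveDefaultStringType (the struct input is illustrative only):

  // Buggy guard from PATCH 20: for a root such as
  // StructType(Seq(StructField("c1", StringType))), the closed-over `dataType` is the
  // struct itself, isDefaultStringType(struct) is false, and the nested default
  // StringType is never replaced.
  dataType.transformRecursively { case _ if isDefaultStringType(dataType) => newType }

  // Fixed guard from PATCH 21: bind the visited node and test it, so default string
  // types nested inside structs, arrays and maps are replaced as well.
  dataType.transformRecursively { case currentType if isDefaultStringType(currentType) => newType }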
From 0138f3ab11bdea925a634d36753bdcc1b0c7062c Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 13 Nov 2024 16:52:03 +0100 Subject: [PATCH 23/54] add docstring --- .../analysis/ResolveDefaultStringType.scala | 13 ++++++++++++- .../collation/DefaultCollationTestSuite.scala | 18 +++++++++--------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index 920454fcb4ac9..450d9c546aa51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -57,6 +57,10 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType + /** + * Returns the default string type that should be used in a given DDL command + * (for now always UTF8_BINARY). + */ private def stringTypeForDDLCommand(table: LogicalPlan): StringType = StringType @@ -65,14 +69,21 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { case _ => false } + /** + * Transforms the given plan, by transforming all expressions in its operators to use the given + * new type instead of the default string type. + */ private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = { plan resolveOperators { operator => - operator.transformExpressionsUp { expression => + operator resolveExpressionsUp { expression => transformExpression(expression, newType) } } } + /** + * Transforms the given expression, by changing all default string types to the given new type. 
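+ * Only ColumnDefinition, Cast and Literal nodes are rewritten here; any other expression is + * returned unchanged.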
+ */ private def transformExpression(expression: Expression, newType: StringType): Expression = { expression match { case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 5fe16570c11e3..5f766052e1294 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -167,11 +167,11 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { withSessionCollationAndTable("UTF8_LCASE", tableName) { sql(s""" - |CREATE TABLE $tableName USING $dataSource AS - |SELECT 'a' = 'A' AS c1 - |UNION ALL - |SELECT 'b' = 'B' AS c1 - |""".stripMargin) + |CREATE TABLE $tableName USING $dataSource AS + |SELECT 'a' = 'A' AS c1 + |UNION ALL + |SELECT 'b' = 'B' AS c1 + |""".stripMargin) checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(false), Row(false))) } } @@ -375,10 +375,10 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { checkAnswer( sql(s""" - |SELECT 'a' = 'A' - |UNION ALL - |SELECT 'b' = 'B' - |""".stripMargin), + |SELECT 'a' = 'A' + |UNION ALL + |SELECT 'b' = 'B' + |""".stripMargin), Seq(Row(true), Row(true))) } } From 27059a8561f423b86af68a84620719bf004c7b2f Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 13 Nov 2024 19:01:09 +0100 Subject: [PATCH 24/54] add check for not using session collation --- .../scala/org/apache/spark/sql/types/StringType.scala | 11 +++++------ .../catalyst/analysis/ResolveDefaultStringType.scala | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 87fd76f0c46f4..982c1934ede58 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -115,15 +115,14 @@ case object StringType extends StringType(0) { } /** - * Data type representing a string with the default collation (what the default - * collation is depends on whether it is a DML or DDL statement). + * Data type representing a string with the default collation (what the default collation is + * depends on whether it is a DML or DDL statement). * - * In DML statements the default collation is determined by the session collation, - * while in DDL statements the default collation is determined by the collation of the object - * (table, view, etc.). + * In DML statements the default collation is determined by the session collation, while in DDL + * statements the default collation is determined by the collation of the object (table, view, etc.).
*/ private[spark] case class DefaultStringType(override val collationId: Int) - extends StringType(collationId) + extends StringType(collationId) private[spark] object DefaultStringType { def apply(collationName: String): DefaultStringType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index 450d9c546aa51..a6dbd41684f1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} */ object ResolveDefaultStringType extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { -// if (isDefaultSessionCollationUsed) { -// return plan -// } + if (isDefaultSessionCollationUsed) { + return plan + } plan resolveOperators { case p if isCreateOrAlterPlan(p) => @@ -58,8 +58,8 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType /** - * Returns the default string type that should be used in a given DDL command - * (for now always UTF8_BINARY). + * Returns the default string type that should be used in a given DDL command (for now always + * UTF8_BINARY). */ From 74414a10a8992d7cd8e4b51967e6508e58b0d3f5 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 14 Nov 2024 12:32:22 +0100 Subject: [PATCH 25/54] initial --- .../catalyst/parser/DataTypeAstBuilder.scala | 2 +- .../apache/spark/sql/types/StringType.scala | 17 ------ .../sql/catalyst/analysis/Analyzer.scala | 3 +- .../analysis/CollationTypeCasts.scala | 3 +- .../analysis/ResolveDefaultStringType.scala | 57 ++++++++++++------- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 4 +- 7 files changed, 44 insertions(+), 44 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 71e8517a4164e..c2cb4a7154076 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -76,7 +76,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => SqlApiConf.get.defaultStringType + case Seq(_) => StringType case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 982c1934ede58..b130453845654 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -113,20 +113,3 @@ case object StringType extends StringType(0) { new StringType(collationId) } } -/** - * Data type representing a string with the default collation (what the default - * collation is depends on - * whether it is a DML
or DDL statement). - * - * In DML statements the default collation is determined by the session collation, while in DDL - * statements the default collation is determined by the collation of the object (table, view, etc.). - */ -private[spark] case class DefaultStringType(override val collationId: Int) - extends StringType(collationId) - -private[spark] object DefaultStringType { - def apply(collationName: String): DefaultStringType = { - val collationId = CollationFactory.collationNameToId(collationName) - new DefaultStringType(collationId) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 418b49c508d35..edd6fc1b1af3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -313,7 +313,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveProcedures :: BindProcedures :: ResolveTableSpec :: - ResolveDefaultStringType :: + new ResolveDefaultStringType(replaceWithTemp = true) :: + new ResolveDefaultStringType(replaceWithTemp = false) :: ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index bd80df0e8bf52..a4832d59b7a6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -24,7 +24,6 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { @@ -203,7 +203,7 @@ object CollationTypeCasts extends TypeCoercionRule { ) } else { - implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType) + implicitTypes.headOption.map(StringType(_)).getOrElse(StringType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index a6dbd41684f1d..95367eebc5bf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -21,7 +21,12 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1DDLCommand, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} +import org.apache.spark.sql.types.{DataType, StringType} + + +case class TemporaryStringType( + override val collationId: Int = SQLConf.get.defaultStringType.collationId) + extends
StringType(collationId) {} /** * Resolves default string types in DDL commands. For DML commands, the default string type is @@ -29,12 +34,37 @@ import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType} * default type of the object (table -> schema -> catalog). However, this is not implemented yet. * So, we will just use UTF8_BINARY for now. */ -object ResolveDefaultStringType extends Rule[LogicalPlan] { +class ResolveDefaultStringType(replaceWithTemp: Boolean) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { if (isDefaultSessionCollationUsed) { - return plan + plan + } else if (isDDLCommand(plan)) { + transformDDL(plan) + } else { + transformPlan(plan, SQLConf.get.defaultStringType) } + } + + private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType + + /** + * Returns the default string type that should be used in a given DDL command (for now always + * UTF8_BINARY). + */ + private def stringTypeForDDLCommand(table: LogicalPlan): StringType = + new StringType(0) + + private def isDDLCommand(plan: LogicalPlan): Boolean = plan match { + case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true + case _ => isCreateOrAlterPlan(plan) + } + private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1DDLCommand => true + case _ => false + } + + private def transformDDL(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { case p if isCreateOrAlterPlan(p) => val newType = stringTypeForDDLCommand(p) @@ -49,26 +79,12 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType)) case alter: AlterColumn - if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) => + if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) => val newType = stringTypeForDDLCommand(alter.table) alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType))) } } - private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType - - /** - * Returns the default string type that should be used in a given DDL command (for now always - * UTF8_BINARY). - */ - private def stringTypeForDDLCommand(table: LogicalPlan): StringType = - StringType - - private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { - case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1DDLCommand => true - case _ => false - } - /** * Transforms the given plan, by transforming all expressions in its operators to use the given * new type instead of the default string type. 
@@ -104,7 +120,8 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { private def isDefaultStringType(dataType: DataType): Boolean = { dataType match { - case _: DefaultStringType => true + case _: TemporaryStringType => + !replaceWithTemp case st: StringType => // should only return true for StringType object and not StringType("UTF8_BINARY") st.eq(StringType) @@ -115,7 +132,7 @@ object ResolveDefaultStringType extends Rule[LogicalPlan] { private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { dataType.transformRecursively { case currentType if isDefaultStringType(currentType) => - newType + if (replaceWithTemp && currentType == newType) TemporaryStringType() else newType } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index da0db687e7d2e..c42a8e1a35ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3315,7 +3315,7 @@ class AstBuilder extends DataTypeAstBuilder * Create a String literal expression. */ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal.create(createString(ctx), conf.defaultStringType) + Literal.create(createString(ctx), StringType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8b4bd705ec296..c7cfa4aa9daa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical.HintErrorHandler import org.apache.spark.sql.catalyst.util.{CollationFactory, DateTimeUtils} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{AtomicType, DefaultStringType, StringType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{AtomicType, StringType, TimestampNTZType, TimestampType} import org.apache.spark.storage.{StorageLevel, StorageLevelMapper} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.{Utils, VersionUtils} @@ -5508,7 +5508,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf { if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") { StringType } else { - DefaultStringType(getConf(DEFAULT_COLLATION)) + StringType(getConf(DEFAULT_COLLATION)) } } From ddfc137a9a2ee5d295644589729486b80f4b7549 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 14 Nov 2024 17:04:40 +0100 Subject: [PATCH 26/54] add v1 and v2 tests --- .../analysis/ResolveDefaultStringType.scala | 4 +- .../catalyst/plans/logical/v2Commands.scala | 6 - .../spark/sql/execution/datasources/ddl.scala | 4 +- .../collation/DefaultCollationTestSuite.scala | 315 +++++++++--------- 4 files changed, 161 insertions(+), 168 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index 95367eebc5bf8..a5a75fb2118c6 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1DDLCommand, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StringType} @@ -60,7 +60,7 @@ class ResolveDefaultStringType(replaceWithTemp: Boolean) extends Rule[LogicalPla } private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { - case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1DDLCommand => true + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index f6eba1552cfff..b465e0e11612f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -88,12 +88,6 @@ trait V2PartitionCommand extends UnaryCommand { override def child: LogicalPlan = table } -/** - * Trait to add to v1 ddl commands so that we can still catch those plan nodes - * in the catalyst analyzer rules. - */ -trait V1DDLCommand {} - /** * Append data to an existing table. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 69186ea4010dc..d9367d92d462e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1DDLCommand} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ case class CreateTable( tableDesc: CatalogTable, mode: SaveMode, - query: Option[LogicalPlan]) extends LogicalPlan with V1DDLCommand { + query: Option[LogicalPlan]) extends LogicalPlan { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 5f766052e1294..7eafbc580e972 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -17,17 +17,20 @@ package org.apache.spark.sql.collation -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.connector.DatasourceV2SQLBase import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StringType -class DefaultCollationTestSuite extends DatasourceV2SQLBase { +abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSession { - val dataSource: String = "parquet" + def dataSource: String = "parquet" + def testTable: String = "test_tbl" + def testView: String = "test_view" - def withSessionCollationAndTable(collation: String, tableNames: String*)(f: => Unit): Unit = { - withTable(tableNames: _*) { + def withSessionCollationAndTable(collation: String, testTables: String*)(f: => Unit): Unit = { + withTable(testTables: _*) { withSessionCollation(collation) { f } @@ -66,201 +69,121 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { // region DDL tests test("create/alter table") { - val tableName = "tbl" - withSessionCollationAndTable("UTF8_LCASE", tableName) { + withSessionCollationAndTable("UTF8_LCASE", testTable) { // create table with implicit collation - sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") - assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") // alter table add column with implicit collation - sql(s"ALTER TABLE $tableName ADD COLUMN c2 STRING") - assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") + sql(s"ALTER TABLE $testTable ADD COLUMN c2 STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") - sql(s"ALTER 
TABLE $tableName ALTER COLUMN c2 TYPE STRING COLLATE UNICODE") - assertTableColumnCollation(tableName, "c2", "UNICODE") + sql(s"ALTER TABLE $testTable ALTER COLUMN c2 TYPE STRING COLLATE UNICODE") + assertTableColumnCollation(testTable, "c2", "UNICODE") - sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE STRING") - assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") + sql(s"ALTER TABLE $testTable ALTER COLUMN c2 TYPE STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") } } test("create table with explicit collation") { - val tableName = "tbl_explicit_collation" - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") - assertTableColumnCollation(tableName, "c1", "UTF8_LCASE") + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_LCASE") } - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UNICODE) USING $dataSource") - assertTableColumnCollation(tableName, "c1", "UNICODE") + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UNICODE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UNICODE") } } test("create table as select") { - val tableName = "tbl" - // literals in select do not pick up session collation - withSessionCollationAndTable("UTF8_LCASE", tableName) { + withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s""" - |CREATE TABLE $tableName USING $dataSource AS SELECT + |CREATE TABLE $testTable USING $dataSource AS SELECT | 'a' AS c1, | 'a' || 'a' AS c2, | SUBSTRING('a', 1, 1) AS c3, | SUBSTRING(SUBSTRING('ab', 1, 1), 1, 1) AS c4, | 'a' = 'A' AS truthy |""".stripMargin) - assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") - assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") - assertTableColumnCollation(tableName, "c3", "UTF8_BINARY") - assertTableColumnCollation(tableName, "c4", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c3", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c4", "UTF8_BINARY") - checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE truthy"), Seq(Row(0))) + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE truthy"), Seq(Row(0))) } // literals in inline table do not pick up session collation - withSessionCollationAndTable("UTF8_LCASE", tableName) { + withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s""" - |CREATE TABLE $tableName USING $dataSource AS + |CREATE TABLE $testTable USING $dataSource AS |SELECT c1, c1 = 'A' as c2 FROM VALUES ('a'), ('A') AS vals(c1) |""".stripMargin) - assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") - checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE c2"), Seq(Row(1))) + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(1))) } // cast in select does not pick up session collation - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName USING $dataSource AS SELECT cast('a' AS STRING) AS c1") - assertTableColumnCollation(tableName, "c1", "UTF8_BINARY") + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable USING $dataSource AS 
SELECT cast('a' AS STRING) AS c1") + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") } } test("ctas with complex types") { - val tableName = "tbl_complex" - withSessionCollationAndTable("UTF8_LCASE", tableName) { + withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s""" - |CREATE TABLE $tableName USING $dataSource AS + |CREATE TABLE $testTable USING $dataSource AS |SELECT | struct('a') AS c1, | map('a', 'b') AS c2, | array('a') AS c3 |""".stripMargin) - checkAnswer(sql(s"SELECT COLLATION(c1.col1) FROM $tableName"), Seq(Row("UTF8_BINARY"))) - checkAnswer(sql(s"SELECT COLLATION(c2['a']) FROM $tableName"), Seq(Row("UTF8_BINARY"))) - checkAnswer(sql(s"SELECT COLLATION(c3[0]) FROM $tableName"), Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c1.col1) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c2['a']) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c3[0]) FROM $testTable"), Seq(Row("UTF8_BINARY"))) } } test("ctas with union") { - val tableName = "tbl_union" - withSessionCollationAndTable("UTF8_LCASE", tableName) { + withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s""" - |CREATE TABLE $tableName USING $dataSource AS + |CREATE TABLE $testTable USING $dataSource AS |SELECT 'a' = 'A' AS c1 |UNION |SELECT 'b' = 'B' AS c1 |""".stripMargin) - checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(false))) + checkAnswer(sql(s"SELECT * FROM $testTable"), Seq(Row(false))) } - withSessionCollationAndTable("UTF8_LCASE", tableName) { + withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s""" - |CREATE TABLE $tableName USING $dataSource AS + |CREATE TABLE $testTable USING $dataSource AS |SELECT 'a' = 'A' AS c1 |UNION ALL |SELECT 'b' = 'B' AS c1 |""".stripMargin) - checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(false), Row(false))) + checkAnswer(sql(s"SELECT * FROM $testTable"), Seq(Row(false), Row(false))) } } test("add column") { - val tableName = "tbl_add_col" - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") - assertTableColumnCollation(tableName, "c1", "UTF8_LCASE") - - sql(s"ALTER TABLE $tableName ADD COLUMN c2 STRING") - assertTableColumnCollation(tableName, "c2", "UTF8_BINARY") - - sql(s"ALTER TABLE $tableName ADD COLUMN c3 STRING COLLATE UNICODE") - assertTableColumnCollation(tableName, "c3", "UNICODE") - } - } - - test("create/alter view created from a table") { - val tableName = "tbl_view" - val viewName = "view2" - val sessionCollation = "UTF8_LCASE" - withSessionCollationAndTable(sessionCollation, tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING, c2 STRING COLLATE UNICODE_CI) USING $dataSource") - sql(s"INSERT INTO $tableName VALUES ('a', 'a'), ('A', 'A')") - - withView(viewName) { - sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") - - assertTableColumnCollation(viewName, "c1", "UTF8_BINARY") - assertTableColumnCollation(viewName, "c2", "UNICODE_CI") - checkAnswer( - sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $viewName"), - Row("UTF8_BINARY", sessionCollation)) - - // filter should use session collation - checkAnswer(sql(s"SELECT COUNT(*) FROM $viewName WHERE 'a' = 'A'"), Row(2)) - - // filter should use column collation - checkAnswer(sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), Row(1)) - - // literal with explicit collation wins - checkAnswer( - sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A' collate UNICODE_CI"), - Row(2)) - - // two 
implicit collations -> errors out - assertThrowsImplicitMismatch(sql(s"SELECT c1 = substring('A', 0, 1) FROM $viewName")) - - sql(s"ALTER VIEW $viewName AS SELECT c1 COLLATE UNICODE_CI AS c1, c2 FROM $tableName") - assertTableColumnCollation(viewName, "c1", "UNICODE_CI") - assertTableColumnCollation(viewName, "c2", "UNICODE_CI") - checkAnswer( - sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $viewName"), - Row("UNICODE_CI", sessionCollation)) + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_LCASE") - // after alter both rows should be returned - checkAnswer(sql(s"SELECT COUNT(*) FROM $viewName WHERE c1 = 'A'"), Row(2)) - } - } - } + sql(s"ALTER TABLE $testTable ADD COLUMN c2 STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") - test("join view with table") { - val viewTableName = "view_table" - val joinTableName = "join_table" - val viewName = "view" - val sessionCollation = "sr" - - withSessionCollationAndTable(sessionCollation, viewTableName, joinTableName) { - sql(s"CREATE TABLE $viewTableName (c1 STRING COLLATE UNICODE_CI) USING $dataSource") - sql(s"CREATE TABLE $joinTableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") - sql(s"INSERT INTO $viewTableName VALUES ('a')") - sql(s"INSERT INTO $joinTableName VALUES ('A')") - - withView(viewName) { - sql(s"CREATE VIEW $viewName AS SELECT * FROM $viewTableName") - - assertThrowsImplicitMismatch( - sql(s"SELECT * FROM $viewName JOIN $joinTableName ON $viewName.c1 = $joinTableName.c1")) - - checkAnswer( - sql(s""" - |SELECT COLLATION($viewName.c1), COLLATION($joinTableName.c1) - |FROM $viewName JOIN $joinTableName - |ON $viewName.c1 = $joinTableName.c1 COLLATE UNICODE_CI - |""".stripMargin), - Row("UNICODE_CI", "UTF8_LCASE")) - } + sql(s"ALTER TABLE $testTable ADD COLUMN c3 STRING COLLATE UNICODE") + assertTableColumnCollation(testTable, "c3", "UNICODE") } } @@ -332,33 +255,31 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } test("having group by is aware of session collation") { - val tableName = "tbl_grp_by" - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") - sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a'), ('A')") // having clause uses session (default) collation checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName GROUP BY c1 HAVING 'a' = 'A'"), + sql(s"SELECT COUNT(*) FROM $testTable GROUP BY c1 HAVING 'a' = 'A'"), Seq(Row(1), Row(1))) // having clause uses column (implicit) collation checkAnswer( - sql(s"SELECT COUNT(*) FROM $tableName GROUP BY c1 HAVING c1 = 'A'"), + sql(s"SELECT COUNT(*) FROM $testTable GROUP BY c1 HAVING c1 = 'A'"), Seq(Row(1))) } } test("min/max are aware of session collation") { // scalastyle:off nonascii - val tableName = "tbl_min_max" - withSessionCollationAndTable("UNICODE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") - sql(s"INSERT INTO $tableName VALUES ('1'), ('½')") + withSessionCollationAndTable("UNICODE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('1'), ('½')") - checkAnswer(sql(s"SELECT MIN(c1) FROM $tableName"), Seq(Row("1"))) + checkAnswer(sql(s"SELECT MIN(c1) FROM 
$testTable"), Seq(Row("1"))) - checkAnswer(sql(s"SELECT MAX(c1) FROM $tableName"), Seq(Row("½"))) + checkAnswer(sql(s"SELECT MAX(c1) FROM $testTable"), Seq(Row("½"))) } // scalastyle:on nonascii } @@ -384,27 +305,105 @@ class DefaultCollationTestSuite extends DatasourceV2SQLBase { } test("literals in insert inherit session level collation") { - val tableName = "tbl_insert" - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 BOOLEAN) USING $dataSource") - sql(s"INSERT INTO $tableName VALUES ('a' = 'A')") + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 BOOLEAN) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a' = 'A')") + + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c1"), Seq(Row(1))) + + sql(s"INSERT INTO $testTable VALUES (array_contains(array('a'), 'A'))") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c1"), Seq(Row(2))) + } + } + + // endregion +} + +class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite { + + test("create/alter view created from a table") { + val sessionCollation = "UTF8_LCASE" + withSessionCollationAndTable(sessionCollation, testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a', 'a'), ('A', 'A')") + + withView(testView) { + sql(s"CREATE VIEW $testView AS SELECT * FROM $testTable") + + assertTableColumnCollation(testView, "c1", "UTF8_BINARY") + assertTableColumnCollation(testView, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $testView"), + Row("UTF8_BINARY", sessionCollation)) + + // filter should use session collation + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE 'a' = 'A'"), Row(2)) + + // filter should use column collation + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A'"), Row(1)) + + // literal with explicit collation wins + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A' collate UNICODE_CI"), + Row(2)) + + // two implicit collations -> errors out + assertThrowsImplicitMismatch(sql(s"SELECT c1 = substring('A', 0, 1) FROM $testView")) + + sql(s"ALTER VIEW $testView AS SELECT c1 COLLATE UNICODE_CI AS c1, c2 FROM $testTable") + assertTableColumnCollation(testView, "c1", "UNICODE_CI") + assertTableColumnCollation(testView, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $testView"), + Row("UNICODE_CI", sessionCollation)) - checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), Seq(Row(1))) + // after alter both rows should be returned + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A'"), Row(2)) + } + } + } - sql(s"INSERT INTO $tableName VALUES (array_contains(array('a'), 'A'))") - checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName WHERE c1"), Seq(Row(2))) + test("join view with table") { + val viewTableName = "view_table" + val joinTableName = "join_table" + val sessionCollation = "sr" + + withSessionCollationAndTable(sessionCollation, viewTableName, joinTableName) { + sql(s"CREATE TABLE $viewTableName (c1 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"CREATE TABLE $joinTableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + sql(s"INSERT INTO $viewTableName VALUES ('a')") + sql(s"INSERT INTO $joinTableName VALUES ('A')") + + withView(testView) { + sql(s"CREATE VIEW $testView AS SELECT * FROM $viewTableName") + + assertThrowsImplicitMismatch( + sql(s"SELECT * FROM 
$testView JOIN $joinTableName ON $testView.c1 = $joinTableName.c1")) + + checkAnswer( + sql(s""" + |SELECT COLLATION($testView.c1), COLLATION($joinTableName.c1) + |FROM $testView JOIN $joinTableName + |ON $testView.c1 = $joinTableName.c1 COLLATE UNICODE_CI + |""".stripMargin), + Row("UNICODE_CI", "UTF8_LCASE")) + } } } +} + +class DefaultCollationTestSuiteV2 extends DefaultCollationTestSuite with DatasourceV2SQLBase { + override def testTable: String = s"testcat.${super.testTable}" + override def testView: String = s"testcat.${super.testView}" + // delete only works on v2 test("delete behavior") { - val tableName = "testcat.tbl_delete" - withSessionCollationAndTable("UTF8_LCASE", tableName) { - sql(s"CREATE TABLE $tableName (c1 STRING) USING $dataSource") - sql(s"INSERT INTO $tableName VALUES ('a'), ('A')") + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a'), ('A')") - sql(s"DELETE FROM $tableName WHERE 'a' = 'A'") - checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName"), Seq(Row(0))) + sql(s"DELETE FROM $testTable WHERE 'a' = 'A'") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable"), Seq(Row(0))) } } - // endregion } From d0ad67368131049d608316ca681aec3620918155 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 14 Nov 2024 17:54:47 +0100 Subject: [PATCH 27/54] add more comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 7 +++++-- .../catalyst/analysis/CollationTypeCoercion.scala | 3 +-- .../catalyst/analysis/ResolveDefaultStringType.scala | 12 ++++++++---- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1fc667265b606..b0ac6b987f96e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -316,8 +316,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveProcedures :: BindProcedures :: ResolveTableSpec :: - new ResolveDefaultStringType(replaceWithTemp = true) :: - new ResolveDefaultStringType(replaceWithTemp = false) :: + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string type. 
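+      // As an illustration only (these two lines are not part of this change, and
+      // assume the collation-aware equals on StringType):
+      //   StringType == StringType("UTF8_BINARY")   // true: equality compares only collationId
+      //   StringType.eq(StringType("UTF8_BINARY"))  // false: distinct instances
+      // A rule that merely swaps one for the other therefore returns a plan that
+      // compares equal to its input and is treated as unchanged, hence the two passes.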
+      new ResolveDefaultStringType(replaceWithTempType = true) ::
+      new ResolveDefaultStringType(replaceWithTempType = false) ::
       ResolveAliases ::
       ResolveSubquery ::
       ResolveSubqueryColumnAliases ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 1e9c3aabedb33..225586b05c009 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -83,7 +83,6 @@ import org.apache.spark.sql.catalyst.expressions.{
   TryToNumber
 }
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}

 /**
@@ -273,7 +272,7 @@ object CollationTypeCoercion {
       )
     } else {
-      implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+      implicitTypes.headOption.map(StringType(_)).getOrElse(StringType)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala
index a5a75fb2118c6..6549738a4b0ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala
@@ -33,8 +33,12 @@ case class TemporaryStringType(
 * determined by the session's default string type. For DDL, the default string type is the
 * default type of the object (table -> schema -> catalog). However, this is not implemented yet.
 * So, we will just use UTF8_BINARY for now.
+ *
+ * `replaceWithTempType` is a flag that determines whether to replace the default string type
+ * with a [[TemporaryStringType]] object in cases where the old and new types are equal, and the
+ * transformation would therefore leave the plan unchanged.
 */
-class ResolveDefaultStringType(replaceWithTemp: Boolean) extends Rule[LogicalPlan] {
+class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
     if (isDefaultSessionCollationUsed) {
       plan
@@ -52,7 +56,7 @@ class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[LogicalPla
 * UTF8_BINARY).
*/ private def stringTypeForDDLCommand(table: LogicalPlan): StringType = - new StringType(0) + StringType(0) private def isDDLCommand(plan: LogicalPlan): Boolean = plan match { case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true @@ -121,7 +125,7 @@ class ResolveDefaultStringType(replaceWithTemp: Boolean) extends Rule[LogicalPla private def isDefaultStringType(dataType: DataType): Boolean = { dataType match { case _: TemporaryStringType => - !replaceWithTemp + !replaceWithTempType case st: StringType => // should only return true for StringType object and not StringType("UTF8_BINARY") st.eq(StringType) @@ -132,7 +136,7 @@ class ResolveDefaultStringType(replaceWithTemp: Boolean) extends Rule[LogicalPla private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { dataType.transformRecursively { case currentType if isDefaultStringType(currentType) => - if (replaceWithTemp && currentType == newType) TemporaryStringType() else newType + if (replaceWithTempType && currentType == newType) TemporaryStringType() else newType } } From 467790b31d8f0eec6c0610c3ac34cd03fc608258 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 19 Nov 2024 17:09:24 +0100 Subject: [PATCH 28/54] fix failing tests --- .../spark/sql/catalyst/expressions/misc.scala | 4 +- .../sql/CollationSQLExpressionsSuite.scala | 254 +----------------- .../org/apache/spark/sql/CollationSuite.scala | 20 -- .../collation/DefaultCollationTestSuite.scala | 6 +- 4 files changed, 18 insertions(+), 266 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5f1b3dc0a01ac..622a0e0aa5bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, MapType(StringType, StringType)) + Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)) override def left: Expression = errorClass override def right: Expression = errorParms diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 3563e04dced11..22dabdb6ee1e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat -import scala.collection.immutable.Seq - -import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} 
+import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable} import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Mode @@ -355,54 +352,6 @@ class CollationSQLExpressionsSuite }) } - test("Support CsvToStructs csv expression with collation") { - case class CsvToStructsTestCase( - input: String, - collationName: String, - schema: String, - options: String, - result: Row, - structFields: Seq[StructField] - ) - - val testCases = Seq( - CsvToStructsTestCase("1", "UTF8_BINARY", "'a INT'", "", - Row(1), Seq( - StructField("a", IntegerType, nullable = true) - )), - CsvToStructsTestCase("true, 0.8", "UTF8_LCASE", "'A BOOLEAN, B DOUBLE'", "", - Row(true, 0.8), Seq( - StructField("A", BooleanType, nullable = true), - StructField("B", DoubleType, nullable = true) - )), - CsvToStructsTestCase("\"Spark\"", "UNICODE", "'a STRING'", "", - Row("Spark"), Seq( - StructField("a", StringType("UNICODE"), nullable = true) - )), - CsvToStructsTestCase("26/08/2015", "UTF8_BINARY", "'time Timestamp'", - ", map('timestampFormat', 'dd/MM/yyyy')", Row( - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0") - ), Seq( - StructField("time", TimestampType, nullable = true) - )) - ) - - // Supported collations - testCases.foreach(t => { - val query = - s""" - |select from_csv('${t.input}', ${t.schema} ${t.options}) - |""".stripMargin - // Result - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { - val testQuery = sql(query) - val queryResult = testQuery.collect().head - checkAnswer(testQuery, Row(t.result)) - val dataType = StructType(t.structFields) - assert(testQuery.schema.fields.head.dataType.sameType(dataType)) - } - }) - } test("Support SchemaOfCsv csv expression with collation") { case class SchemaOfCsvTestCase( @@ -1082,7 +1031,7 @@ class CollationSQLExpressionsSuite val query = s"SELECT raise_error('${t.errorMessage}')" // Result & data type checkError( - exception = intercept[SparkRuntimeException] { + exception = intercept[SparkThrowable] { sql(query).collect() }, condition = "USER_RAISED_EXCEPTION", @@ -1270,86 +1219,6 @@ class CollationSQLExpressionsSuite ) } - test("Support XmlToStructs xml expression with collation") { - case class XmlToStructsTestCase( - input: String, - collationName: String, - schema: String, - options: String, - result: Row, - structFields: Seq[StructField] - ) - - val testCases = Seq( - XmlToStructsTestCase("
<p>\n  <a>1</a>\n</p>", "UTF8_BINARY", "'a INT'", "",
-        Row(1), Seq(
-          StructField("a", IntegerType, nullable = true)
-        )),
-      XmlToStructsTestCase("<p>\n  <A>true</A><B>0.8</B>\n</p>", "UTF8_LCASE",
-        "'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq(
-          StructField("A", BooleanType, nullable = true),
-          StructField("B", DoubleType, nullable = true)
-        )),
-      XmlToStructsTestCase("<p>\n  <s>Spark</s>\n</p>", "UNICODE", "'s STRING'", "",
-        Row("Spark"), Seq(
-          StructField("s", StringType("UNICODE"), nullable = true)
-        )),
-      XmlToStructsTestCase("<p>\n  <time>26/08/2015</time>\n</p>", "UNICODE_CI", "'time Timestamp'",
-        ", map('timestampFormat', 'dd/MM/yyyy')", Row(
-          new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0")
-        ), Seq(
-          StructField("time", TimestampType, nullable = true)
-        ))
-    )
-
-    // Supported collations
-    testCases.foreach(t => {
-      val query =
-        s"""
-           |select from_xml('${t.input}', ${t.schema} ${t.options})
-           |""".stripMargin
-      // Result
-      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
-        val testQuery = sql(query)
-        checkAnswer(testQuery, Row(t.result))
-        val dataType = StructType(t.structFields)
-        assert(testQuery.schema.fields.head.dataType.sameType(dataType))
-      }
-    })
-  }
-
-  test("Support SchemaOfXml xml expression with collation") {
-    case class SchemaOfXmlTestCase(
-      input: String,
-      collationName: String,
-      result: String
-    )
-
-    val testCases = Seq(
-      SchemaOfXmlTestCase("
<p>\n  <a>1</a>\n</p>", "UTF8_BINARY", "STRUCT<a: BIGINT>"),
-      SchemaOfXmlTestCase("<p>\n  <A>true</A><B>0.8</B>\n</p>", "UTF8_LCASE",
-        "STRUCT<A: BOOLEAN, B: DOUBLE>"),
-      SchemaOfXmlTestCase("<p></p>", "UNICODE", "STRUCT<>"),
-      SchemaOfXmlTestCase("<p>\n  <A>1</A><A>2</A><A>3</A>\n</p>", "UNICODE_CI",
-        "STRUCT<A: ARRAY<BIGINT>>")
-    )
-
-    // Supported collations
-    testCases.foreach(t => {
-      val query =
-        s"""
-           |select schema_of_xml('${t.input}')
-           |""".stripMargin
-      // Result
-      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
-        val testQuery = sql(query)
-        checkAnswer(testQuery, Row(t.result))
-        val dataType = StringType(t.collationName)
-        assert(testQuery.schema.fields.head.dataType.sameType(dataType))
-      }
-    })
-  }
-
   test("Support StructsToXml xml expression with collation") {
     case class StructsToXmlTestCase(
       input: String,
       collationName: String,
       result: String
     )
@@ -1502,55 +1371,6 @@ class CollationSQLExpressionsSuite
     })
   }

-  test("Support VariantGet & TryVariantGet variant expressions with collation") {
-    case class VariantGetTestCase(
-      input: String,
-      path: String,
-      variantType: String,
-      collationName: String,
-      result: Any,
-      resultType: DataType
-    )
-
-    val testCases = Seq(
-      VariantGetTestCase("{\"a\": 1}", "$.a", "int", "UTF8_BINARY", 1, IntegerType),
-      VariantGetTestCase("{\"a\": 1}", "$.b", "int", "UTF8_LCASE", null, IntegerType),
-      VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", StringType("UNICODE")),
-      VariantGetTestCase("[1, \"2\"]", "$[2]", "string", "UNICODE_CI", null,
-        StringType("UNICODE_CI"))
-    )
-
-    // Supported collations (VariantGet)
-    testCases.foreach(t => {
-      val query =
-        s"""
-           |SELECT variant_get(parse_json('${t.input}'), '${t.path}', '${t.variantType}')
-           |""".stripMargin
-      // Result & data type
-      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
-        val testQuery = sql(query)
-        val testResult = testQuery.collect().map(_.toString()).mkString("")
-        assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant
-        assert(testQuery.schema.fields.head.dataType.sameType(t.resultType))
-      }
-    })
-
-    // Supported collations (TryVariantGet)
-    testCases.foreach(t => {
-      val query =
-        s"""
-           |SELECT try_variant_get(parse_json('${t.input}'), '${t.path}', '${t.variantType}')
-           |""".stripMargin
-      // Result & data type
-      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
-        val testQuery = sql(query)
-        val testResult = testQuery.collect().map(_.toString()).mkString("")
-        assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant
-        assert(testQuery.schema.fields.head.dataType.sameType(t.resultType))
-      }
-    })
-  }
-
   test("Handle invalid JSON for VariantGet variant expression with collation") {
     // variant_get should throw an exception if the cast fails
     val json = "[1, \"Spark\"]"
@@ -1578,65 +1398,6 @@ class CollationSQLExpressionsSuite
     }
   }

-  test("Support VariantExplode variant expressions with collation") {
-    case class VariantExplodeTestCase(
-      input: String,
-      collationName: String,
-      result: String,
-      resultType: Seq[StructField]
-    )
-
-    val testCases = Seq(
-      VariantExplodeTestCase("[\"hello\", \"world\"]", "UTF8_BINARY",
-        Row(0, "null", "\"hello\"").toString() + Row(1, "null", "\"world\"").toString(),
-        Seq[StructField](
-          StructField("pos", IntegerType, nullable = false),
-          StructField("key", StringType("UTF8_BINARY")),
-          StructField("value", VariantType, nullable = false)
-        )
-      ),
-      VariantExplodeTestCase("[\"Spark\", \"SQL\"]", "UTF8_LCASE",
-        Row(0, "null", "\"Spark\"").toString() + Row(1, "null", "\"SQL\"").toString(),
-        Seq[StructField](
-          StructField("pos", IntegerType, nullable = false),
-          StructField("key", StringType("UTF8_LCASE")),
-          StructField("value", VariantType, nullable = false)
-        )
-      ),
-      VariantExplodeTestCase("{\"a\": true, \"b\": 3.14}", "UNICODE",
-        Row(0, "a", "true").toString() + Row(1, 
"b", "3.14").toString(), - Seq[StructField]( - StructField("pos", IntegerType, nullable = false), - StructField("key", StringType("UNICODE")), - StructField("value", VariantType, nullable = false) - ) - ), - VariantExplodeTestCase("{\"A\": 9.99, \"B\": false}", "UNICODE_CI", - Row(0, "A", "9.99").toString() + Row(1, "B", "false").toString(), - Seq[StructField]( - StructField("pos", IntegerType, nullable = false), - StructField("key", StringType("UNICODE_CI")), - StructField("value", VariantType, nullable = false) - ) - ) - ) - - // Supported collations - testCases.foreach(t => { - val query = - s""" - |SELECT * from variant_explode(parse_json('${t.input}')) - |""".stripMargin - // Result & data type - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { - val testQuery = sql(query) - val testResult = testQuery.collect().map(_.toString()).mkString("") - assert(testResult === t.result) // can't use checkAnswer for Variant - assert(testQuery.schema.fields.sameElements(t.resultType)) - } - }) - } - test("Support SchemaOfVariant variant expressions with collation") { case class SchemaOfVariantTestCase( input: String, @@ -2416,8 +2177,15 @@ class CollationSQLExpressionsSuite |collate('${testCase.left}', '${testCase.leftCollation}'))= |collate('${testCase.right}', '${testCase.rightCollation}'); |""".stripMargin - val testQuery = sql(query) - checkAnswer(testQuery, Row(testCase.result)) + + if (testCase.leftCollation == testCase.rightCollation) { + checkAnswer(sql(query), Row(testCase.result)) + } else { + val exception = intercept[AnalysisException] { + sql(query) + } + assert(exception.getCondition === "COLLATION_MISMATCH.EXPLICIT") + } }) val queryPass = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9b0dd9b984227..d6de5003b3c5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -851,26 +851,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("SPARK-47692: Parameter markers with variable mapping") { - checkAnswer( - spark.sql( - "SELECT collation(:var1 || :var2)", - Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")), - "var2" -> Literal.create('b', StringType("UNICODE")))), - Seq(Row("UTF8_BINARY")) - ) - - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - checkAnswer( - spark.sql( - "SELECT collation(:var1 || :var2)", - Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")), - "var2" -> Literal.create('b', StringType("UNICODE")))), - Seq(Row("UNICODE")) - ) - } - } - test("SPARK-47210: Cast of default collated strings in IN expression") { val tableName = "t1" withTable(tableName) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 7eafbc580e972..af9462a4bf46d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -342,13 +342,17 @@ class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite { // filter should use column collation checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A'"), Row(1)) + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = substring('A', 0, 1)"), + Row(1)) + // literal with 
explicit collation wins checkAnswer( sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A' collate UNICODE_CI"), Row(2)) // two implicit collations -> errors out - assertThrowsImplicitMismatch(sql(s"SELECT c1 = substring('A', 0, 1) FROM $testView")) + assertThrowsImplicitMismatch(sql(s"SELECT c1 = c2 FROM $testView")) sql(s"ALTER VIEW $testView AS SELECT c1 COLLATE UNICODE_CI AS c1, c2 FROM $testTable") assertTableColumnCollation(testView, "c1", "UNICODE_CI") From 134458096edc7c7f916f58c82b20f77e926cf04f Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 20 Nov 2024 19:11:55 +0100 Subject: [PATCH 29/54] add v1 api trait back --- .../internal/types/AbstractStringType.scala | 3 +- .../analysis/CollationTypeCoercion.scala | 321 ++++++++++-------- .../analysis/ResolveDefaultStringType.scala | 4 +- .../analysis/TypeCoercionHelper.scala | 4 +- .../expressions/complexTypeExtractors.scala | 1 + .../catalyst/plans/logical/v2Commands.scala | 2 + .../spark/sql/execution/datasources/ddl.scala | 4 +- .../collation/DefaultCollationTestSuite.scala | 26 ++ 8 files changed, 222 insertions(+), 143 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 49d8bf9e001ab..6dcb8a876b7a2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.internal.types -import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** @@ -26,7 +25,7 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} abstract class AbstractStringType(supportsTrimCollation: Boolean = false) extends AbstractDataType with Serializable { - override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType + override private[sql] def defaultConcreteType: DataType = StringType override private[sql] def simpleString: String = "string" override private[sql] def acceptsType(other: DataType): Boolean = other match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index 225586b05c009..d598b754691da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -17,94 +17,51 @@ package org.apache.spark.sql.catalyst.analysis -import javax.annotation.Nullable - import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ - ArrayAppend, - ArrayContains, - ArrayExcept, - ArrayIntersect, - ArrayJoin, - ArrayPosition, - ArrayRemove, - ArraysOverlap, - ArrayUnion, - CaseWhen, - Cast, - Coalesce, - Collate, - Concat, - ConcatWs, - Contains, - CreateArray, - CreateMap, - Elt, - EndsWith, - EqualNullSafe, - EqualTo, - Expression, - FindInSet, - GetMapValue, - GreaterThan, - GreaterThanOrEqual, - Greatest, - If, - In, - InSubquery, - Lag, - Lead, - Least, - LessThan, - LessThanOrEqual, - Levenshtein, - Literal, - Mask, - Overlay, - RaiseError, - RegExpReplace, - SplitPart, - StartsWith, - StringInstr, - StringLocate, - 
StringLPad, - StringReplace, - StringRPad, - StringSplitSQL, - StringToMap, - StringTranslate, - StringTrim, - StringTrimLeft, - StringTrimRight, - SubstringIndex, - ToNumber, - TryToNumber -} +import org.apache.spark.sql.catalyst.analysis.CollationStrength.{Default, Explicit, Implicit} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.haveSameType +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} +import org.apache.spark.sql.util.SchemaUtils /** * Type coercion helper that matches against expressions in order to apply collation type coercion. */ object CollationTypeCoercion { + + private val COLLATION_CONTEXT_TAG = new TreeNodeTag[CollationContext]("collationContext") + + private def hasCollationContextTag(expr: Expression): Boolean = { + expr.getTagValue(COLLATION_CONTEXT_TAG).isDefined + } + def apply(expression: Expression): Expression = expression match { + case cast: Cast if shouldRemoveCast(cast) => + cast.child + case ifExpr: If => ifExpr.withNewChildren( ifExpr.predicate +: collateToSingleType(Seq(ifExpr.trueValue, ifExpr.falseValue)) ) case caseWhenExpr: CaseWhen if !haveSameType(caseWhenExpr.inputTypesForMerging) => - val outputStringType = - getOutputCollation(caseWhenExpr.branches.map(_._2) ++ caseWhenExpr.elseValue) - val newBranches = caseWhenExpr.branches.map { - case (condition, value) => - (condition, castStringType(value, outputStringType).getOrElse(value)) + val outputStringType = findLeastCommonStringType( + caseWhenExpr.branches.map(_._2) ++ caseWhenExpr.elseValue) + outputStringType match { + case Some(st) => + val newBranches = caseWhenExpr.branches.map { case (condition, value) => + (condition, castStringType(value, st)) + } + val newElseValue = + caseWhenExpr.elseValue.map(e => castStringType(e, st)) + CaseWhen(newBranches, newElseValue) + + case _ => + caseWhenExpr } - val newElseValue = - caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e)) - CaseWhen(newBranches, newElseValue) case stringLocate: StringLocate => stringLocate.withNewChildren( @@ -136,13 +93,6 @@ object CollationTypeCoercion { val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad)) stringPadExpr.withNewChildren(Seq(newStr, len, newPad)) - case raiseError: RaiseError => - val newErrorParams = raiseError.errorParms.dataType match { - case MapType(StringType, StringType, _) => raiseError.errorParms - case _ => Cast(raiseError.errorParms, MapType(StringType, StringType)) - } - raiseError.withNewChildren(Seq(raiseError.errorClass, newErrorParams)) - case framelessOffsetWindow @ (_: Lag | _: Lead) => val Seq(input, offset, default) = framelessOffsetWindow.children val Seq(newInput, newDefault) = collateToSingleType(Seq(input, default)) @@ -192,88 +142,189 @@ object CollationTypeCoercion { case other => other } + /** + * If childType is collated and target is UTF8_BINARY, the collation of the output + * should be that of the childType. 
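+   * For example (illustration only, not taken verbatim from this patch): for a
+   * column c1 of type STRING COLLATE UNICODE_CI, CAST(c1 AS STRING) keeps
+   * UNICODE_CI rather than resetting it to UTF8_BINARY; the cast is simply
+   * removed so that c1's own collation flows through.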
+ */ + private def shouldRemoveCast(cast: Cast): Boolean = { + val isUserDefined = cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined + val isChildTypeCollatedString = cast.child.dataType match { + case st: StringType => !st.isUTF8BinaryCollation + case _ => false + } + val targetType = cast.dataType + + isUserDefined && isChildTypeCollatedString && targetType == StringType + } + /** * Extracts StringTypes from filtered hasStringType */ @tailrec - private def extractStringType(dt: DataType): StringType = dt match { - case st: StringType => st + private def extractStringType(dt: DataType): Option[StringType] = dt match { + case st: StringType => Some(st) case ArrayType(et, _) => extractStringType(et) + case _ => None } /** * Casts given expression to collated StringType with id equal to collationId only * if expression has StringType in the first place. - * @param expr - * @param collationId - * @return */ - def castStringType(expr: Expression, st: StringType): Option[Expression] = - castStringType(expr.dataType, st).map { dt => Cast(expr, dt)} + def castStringType(expr: Expression, st: StringType): Expression = { + castStringType(expr.dataType, st) + .map(dt => Cast(expr, dt)) + .getOrElse(expr) + } private def castStringType(inType: DataType, castType: StringType): Option[DataType] = { - @Nullable val ret: DataType = inType match { - case st: StringType if st.collationId != castType.collationId => castType + inType match { + case st: StringType if st.collationId != castType.collationId => + Some(castType) case ArrayType(arrType, nullable) => - castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull - case _ => null + castStringType(arrType, castType).map(ArrayType(_, nullable)) + case _ => None } - Option(ret) } /** * Collates input expressions to a single collation. */ - def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { - val st = getOutputCollation(exprs) + def collateToSingleType(expressions: Seq[Expression]): Seq[Expression] = { + val lctOpt = findLeastCommonStringType(expressions) - exprs.map(e => castStringType(e, st).getOrElse(e)) + lctOpt match { + case Some(lct) => + expressions.map(e => castStringType(e, lct)) + case _ => + expressions + } } /** - * Based on the data types of the input expressions this method determines - * a collation type which the output will have. This function accepts Seq of - * any expressions, but will only be affected by collated StringTypes or - * complex DataTypes with collated StringTypes (e.g. ArrayType) + * Tries to find the least common StringType among the given expressions. 
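+   * For example (illustration only): in c1 = 'a' the literal carries Default
+   * strength and c1 carries Implicit strength, so c1's collation wins; an
+   * explicit COLLATE on either side overrides both.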
*/ - def getOutputCollation(expr: Seq[Expression]): StringType = { - val explicitTypes = expr.filter { - case _: Collate => true - case _ => false - } - .map(_.dataType.asInstanceOf[StringType].collationId) - .distinct - - explicitTypes.size match { - // We have 1 explicit collation - case 1 => StringType(explicitTypes.head) - // Multiple explicit collations occurred - case size if size > 1 => - throw QueryCompilationErrors - .explicitCollationMismatchError( - explicitTypes.map(t => StringType(t)) - ) - // Only implicit or default collations present - case 0 => - val implicitTypes = expr.filter { - case Literal(_, _: StringType) => false - case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty => - cast.child.dataType.isInstanceOf[StringType] - case _ => true - } - .map(_.dataType) - .filter(hasStringType) - .map(extractStringType(_).collationId) - .distinct - - if (implicitTypes.length > 1) { - throw QueryCompilationErrors.implicitCollationMismatchError( - implicitTypes.map(t => StringType(t)) - ) - } - else { - implicitTypes.headOption.map(StringType(_)).getOrElse(StringType) + private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = { + if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { + return None + } + + val collationContextWinner = expressions.foldLeft(findCollationContext(expressions.head)) { + case (Some(left), right) => + findCollationContext(right).flatMap { ctx => + collationPrecedenceWinner(left, ctx) } + case (None, _) => return None + } + + collationContextWinner.flatMap { cc => + extractStringType(cc.dataType) + } + } + + /** + * Tries to find the collation context for the given expression. + * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the expression, + * so that the context can be reused later. + */ + private def findCollationContext(expr: Expression): Option[CollationContext] = { + val contextOpt = expr match { + case _ if hasCollationContextTag(expr) => + Some(expr.getTagValue(COLLATION_CONTEXT_TAG).get) + + case _ if !expr.dataType.existsRecursively(_.isInstanceOf[StringType]) => + None + + case collate: Collate => + Some(CollationContext(collate.dataType, Explicit)) + + case _: Alias | _: SubqueryExpression | _: AttributeReference | _: VariableReference => + Some(CollationContext(expr.dataType, Implicit)) + + case _: Literal => + Some(CollationContext(expr.dataType, Default)) + + case extract: ExtractValue => + findCollationContext(extract.child) + .map(cc => CollationContext(extract.dataType, cc.strength)) + + case _ if expr.children.isEmpty => + Some(CollationContext(expr.dataType, Default)) + + case _ => + expr.children + .flatMap(findCollationContext) + .foldLeft(Option.empty[CollationContext]) { + case (Some(left), right) => + collationPrecedenceWinner(left, right) + case (None, right) => + Some(right) + } + } + + contextOpt.foreach(expr.setTagValue(COLLATION_CONTEXT_TAG, _)) + contextOpt + } + + /** + * Returns the collation context that wins in precedence between left and right. 
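+   * Precedence, as a sketch of the cases handled below: Explicit beats Implicit,
+   * which beats Default; two different collations of equal Explicit (or Implicit)
+   * strength raise COLLATION_MISMATCH.EXPLICIT (or COLLATION_MISMATCH.IMPLICIT).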
+ */ + private def collationPrecedenceWinner( + left: CollationContext, + right: CollationContext): Option[CollationContext] = { + + val (leftStringType, rightStringType) = + (extractStringType(left.dataType), extractStringType(right.dataType)) match { + case (Some(l), Some(r)) => + (l, r) + case (None, None) => + return None + case (Some(_), None) => + return Some(left) + case (None, Some(_)) => + return Some(right) + } + + (left.strength, right.strength) match { + case (Explicit, Explicit) if leftStringType != rightStringType => + throw QueryCompilationErrors.explicitCollationMismatchError( + Seq(leftStringType, rightStringType)) + + case (Explicit, _) | (_, Explicit) => + if (left.strength == Explicit) Some(left) else Some(right) + + case (Implicit, Implicit) if leftStringType != rightStringType => + throw QueryCompilationErrors.implicitCollationMismatchError( + Seq(leftStringType, rightStringType)) + + case (Implicit, _) | (_, Implicit) => + if (left.strength == Implicit) Some(left) else Some(right) + + case (Default, Default) if leftStringType != rightStringType => + throw QueryCompilationErrors.explicitCollationMismatchError( + Seq(leftStringType, rightStringType)) + + case _ => + Some(left) } } } + +/** + * Represents the strength of collation used for determining precedence in collation resolution. + */ +private sealed trait CollationStrength {} + + private object CollationStrength { + case object Explicit extends CollationStrength {} + case object Implicit extends CollationStrength {} + case object Default extends CollationStrength {} +} + +/** + * Encapsulates the context for collation, including data type and strength. + * + * @param dataType The data type associated with this collation context. + * @param strength The strength level of the collation, which determines its precedence. 
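+ * For example (illustration only): 'a' COLLATE UNICODE yields
+ * CollationContext(StringType("UNICODE"), Explicit), while a bare literal 'a'
+ * yields CollationContext(StringType, Default).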
+ */ +private case class CollationContext(dataType: DataType, strength: CollationStrength) {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index 6549738a4b0ac..d8230292f91f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StringType} @@ -64,7 +64,7 @@ class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[Logica } private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { - case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1CreateTablePlan => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala index 5b4d76a2a73ed..3fc4b71c986ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala @@ -415,7 +415,7 @@ abstract class TypeCoercionHelper { if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => val newChildren = c.children.map { e => - implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } c.copy(children = newChildren) case other => other @@ -465,7 +465,7 @@ abstract class TypeCoercionHelper { if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { children.tail.map { e => - implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } } else { children.tail diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3b8d4e09905e5..2013cd8d6e636 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -93,6 +93,7 @@ object ExtractValue { trait ExtractValue extends Expression { override def nullIntolerant: Boolean = true final override val nodePatterns: Seq[TreePattern] = Seq(EXTRACT_VALUE) + val child: Expression } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 
b465e0e11612f..bffefbf2f4fea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -459,6 +459,8 @@ trait V2CreateTableAsSelectPlan newQuery: LogicalPlan): V2CreateTableAsSelectPlan } +trait V1CreateTablePlan extends LogicalPlan {} + /** A trait used for logical plan nodes that create or replace V2 table definitions. */ trait V2CreateTablePlan extends LogicalPlan { def name: LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index d9367d92d462e..eb9d5813cff7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ case class CreateTable( tableDesc: CatalogTable, mode: SaveMode, - query: Option[LogicalPlan]) extends LogicalPlan { + query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index af9462a4bf46d..21c3104690b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -187,6 +187,32 @@ abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSessi } } + test("subsequent analyzer iterations correctly resolve default string types") { + // since concat coercion happens after resolving default types this test + // makes sure that we are correctly resolving the default string types + // in subsequent analyzer iterations + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource AS + |SELECT CONCAT(X'68656C6C6F', 'world') AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT c1 FROM $testTable"), Seq(Row("helloworld"))) + } + + // ELT is similar + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource AS + |SELECT ELT(1, X'68656C6C6F', 'world') AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT c1 FROM $testTable"), Seq(Row("hello"))) + } + } + // endregion // region DML tests From 658331a25a8d3f407ebb5d0d82cc901cb6dcd53f Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 20 Nov 2024 19:52:08 +0100 Subject: [PATCH 30/54] fix tests that use schema as a parser --- .../sql/CollationSQLExpressionsSuite.scala | 254 +++++++++++++++++- 1 file changed, 252 insertions(+), 2 deletions(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 22dabdb6ee1e7..637a43ff2038d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat -import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable} +import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Mode @@ -352,6 +353,58 @@ class CollationSQLExpressionsSuite }) } + test("Support CsvToStructs csv expression with collation") { + case class CsvToStructsTestCase( + input: String, + collationName: String, + schema: String, + options: String, + result: Row, + structFields: Seq[StructField] + ) + + val testCases = Seq( + CsvToStructsTestCase("1", "UTF8_BINARY", "'a INT'", "", + Row(1), Seq( + StructField("a", IntegerType, nullable = true) + )), + CsvToStructsTestCase("true, 0.8", "UTF8_LCASE", "'A BOOLEAN, B DOUBLE'", "", + Row(true, 0.8), Seq( + StructField("A", BooleanType, nullable = true), + StructField("B", DoubleType, nullable = true) + )), + CsvToStructsTestCase("\"Spark\"", "UNICODE", "'a STRING'", "", + Row("Spark"), Seq( + StructField("a", StringType, nullable = true) + )), + CsvToStructsTestCase("\"Spark\"", "UTF8_BINARY", "'a STRING COLLATE UNICODE'", "", + Row("Spark"), Seq( + StructField("a", StringType("UNICODE"), nullable = true) + )), + CsvToStructsTestCase("26/08/2015", "UTF8_BINARY", "'time Timestamp'", + ", map('timestampFormat', 'dd/MM/yyyy')", Row( + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0") + ), Seq( + StructField("time", TimestampType, nullable = true) + )) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select from_csv('${t.input}', ${t.schema} ${t.options}) + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val queryResult = testQuery.collect().head + checkAnswer(testQuery, Row(t.result)) + val dataType = StructType(t.structFields) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } test("Support SchemaOfCsv csv expression with collation") { case class SchemaOfCsvTestCase( @@ -1031,7 +1084,7 @@ class CollationSQLExpressionsSuite val query = s"SELECT raise_error('${t.errorMessage}')" // Result & data type checkError( - exception = intercept[SparkThrowable] { + exception = intercept[SparkRuntimeException] { sql(query).collect() }, condition = "USER_RAISED_EXCEPTION", @@ -1219,6 +1272,90 @@ class CollationSQLExpressionsSuite ) } + test("Support XmlToStructs xml expression with collation") { + case class XmlToStructsTestCase( + input: String, + collationName: String, + schema: String, + options: String, + result: Row, + structFields: Seq[StructField] + ) + + val testCases = Seq( + XmlToStructsTestCase("
<p><a>1</a></p>
", "UTF8_BINARY", "'a INT'", "", + Row(1), Seq( + StructField("a", IntegerType, nullable = true) + )), + XmlToStructsTestCase("
<p><A>true</A><B>0.8</B></p>
", "UTF8_LCASE", + "'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq( + StructField("A", BooleanType, nullable = true), + StructField("B", DoubleType, nullable = true) + )), + XmlToStructsTestCase("
<p><s>Spark</s></p>
", "UNICODE", "'s STRING'", "", + Row("Spark"), Seq( + StructField("s", StringType, nullable = true) + )), + XmlToStructsTestCase("
<p><s>Spark</s></p>
", "UTF8_BINARY", "'s STRING COLLATE UNICODE'", "", + Row("Spark"), Seq( + StructField("s", StringType("UNICODE"), nullable = true) + )), + XmlToStructsTestCase("
<p><time>26/08/2015</time></p>
", "UNICODE_CI", "'time Timestamp'", + ", map('timestampFormat', 'dd/MM/yyyy')", Row( + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0") + ), Seq( + StructField("time", TimestampType, nullable = true) + )) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select from_xml('${t.input}', ${t.schema} ${t.options}) + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StructType(t.structFields) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support SchemaOfXml xml expression with collation") { + case class SchemaOfXmlTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + SchemaOfXmlTestCase("
<p><a>1</a></p>
", "UTF8_BINARY", "STRUCT"), + SchemaOfXmlTestCase("
<p><A>true</A><B>0.8</B></p>
", "UTF8_LCASE", + "STRUCT"), + SchemaOfXmlTestCase("
<p></p>
", "UNICODE", "STRUCT<>"), + SchemaOfXmlTestCase("
<p><A>1</A><A>2</A><A>3</A></p>
", "UNICODE_CI", + "STRUCT>") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select schema_of_xml('${t.input}') + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + test("Support StructsToXml xml expression with collation") { case class StructsToXmlTestCase( input: String, @@ -1332,6 +1469,60 @@ class CollationSQLExpressionsSuite } } + test("Support VariantGet & TryVariantGet variant expressions with collation") { + case class VariantGetTestCase( + input: String, + path: String, + variantType: String, + collationName: String, + result: Any, + resultType: DataType + ) + + val testCases = Seq( + VariantGetTestCase("{\"a\": 1}", "$.a", "int", "UTF8_BINARY", 1, IntegerType), + VariantGetTestCase("{\"a\": 1}", "$.b", "int", "UTF8_LCASE", null, IntegerType), + VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", + StringType), + VariantGetTestCase("[1, \"2\"]", "$[1]", "string collate unicode", "UTF8_BINARY", "2", + StringType("UNICODE")), + VariantGetTestCase("[1, \"2\"]", "$[2]", "string", "UNICODE_CI", null, + StringType), + VariantGetTestCase("[1, \"2\"]", "$[2]", "string collate unicode_CI", "UTF8_BINARY", null, + StringType("UNICODE_CI")) + ) + + // Supported collations (VariantGet) + testCases.foreach(t => { + val query = + s""" + |SELECT variant_get(parse_json('${t.input}'), '${t.path}', '${t.variantType}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant + assert(testQuery.schema.fields.head.dataType.sameType(t.resultType)) + } + }) + + // Supported collations (TryVariantGet) + testCases.foreach(t => { + val query = + s""" + |SELECT try_variant_get(parse_json('${t.input}'), '${t.path}', '${t.variantType}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant + assert(testQuery.schema.fields.head.dataType.sameType(t.resultType)) + } + }) + } + test("Handle invalid JSON for TryParseJson variant expression with collation") { // try_parse_json shouldn't throw an exception when the string is not valid JSON value val json = "{\"a\":1,]" @@ -1398,6 +1589,65 @@ class CollationSQLExpressionsSuite } } + test("Support VariantExplode variant expressions with collation") { + case class VariantExplodeTestCase( + input: String, + collationName: String, + result: String, + resultType: Seq[StructField] + ) + + val testCases = Seq( + VariantExplodeTestCase("[\"hello\", \"world\"]", "UTF8_BINARY", + Row(0, "null", "\"hello\"").toString() + Row(1, "null", "\"world\"").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UTF8_BINARY")), + StructField("value", VariantType, nullable = false) + ) + ), + VariantExplodeTestCase("[\"Spark\", \"SQL\"]", "UTF8_LCASE", + Row(0, "null", "\"Spark\"").toString() + Row(1, "null", "\"SQL\"").toString(), + Seq[StructField]( + StructField("pos", 
IntegerType, nullable = false), + StructField("key", StringType("UTF8_LCASE")), + StructField("value", VariantType, nullable = false) + ) + ), + VariantExplodeTestCase("{\"a\": true, \"b\": 3.14}", "UNICODE", + Row(0, "a", "true").toString() + Row(1, "b", "3.14").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UNICODE")), + StructField("value", VariantType, nullable = false) + ) + ), + VariantExplodeTestCase("{\"A\": 9.99, \"B\": false}", "UNICODE_CI", + Row(0, "A", "9.99").toString() + Row(1, "B", "false").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UNICODE_CI")), + StructField("value", VariantType, nullable = false) + ) + ) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT * from variant_explode(parse_json('${t.input}')) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === t.result) // can't use checkAnswer for Variant + assert(testQuery.schema.fields.sameElements(t.resultType)) + } + }) + } + test("Support SchemaOfVariant variant expressions with collation") { case class SchemaOfVariantTestCase( input: String, From dd823be60441af14873739c5556f6d7a3f2f68a9 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 20 Nov 2024 20:02:02 +0100 Subject: [PATCH 31/54] format methods --- .../sql/CollationSQLExpressionsSuite.scala | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 637a43ff2038d..8fd9de490334c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -1469,7 +1469,46 @@ class CollationSQLExpressionsSuite } } - test("Support VariantGet & TryVariantGet variant expressions with collation") { + test("Handle invalid JSON for TryParseJson variant expression with collation") { + // try_parse_json shouldn't throw an exception when the string is not valid JSON value + val json = "{\"a\":1,]" + val query = s"SELECT try_parse_json('$json');" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === s"[null]") + } + } + + test("Support IsVariantNull variant expressions with collation") { + case class IsVariantNullTestCase( + input: String, + collationName: String, + result: Boolean + ) + + val testCases = Seq( + IsVariantNullTestCase("'null'", "UTF8_BINARY", result = true), + IsVariantNullTestCase("'\"null\"'", "UTF8_LCASE", result = false), + IsVariantNullTestCase("'13'", "UNICODE", result = false), + IsVariantNullTestCase("null", "UNICODE_CI", result = false) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT is_variant_null(parse_json(${t.input})) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + } + }) + } + + test("Support VariantGet & TryVariantGet variant expressions with collation") { case class VariantGetTestCase( input: 
String, path: String, @@ -1523,45 +1562,6 @@ class CollationSQLExpressionsSuite }) } - test("Handle invalid JSON for TryParseJson variant expression with collation") { - // try_parse_json shouldn't throw an exception when the string is not valid JSON value - val json = "{\"a\":1,]" - val query = s"SELECT try_parse_json('$json');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - val testQuery = sql(query) - val testResult = testQuery.collect().map(_.toString()).mkString("") - assert(testResult === s"[null]") - } - } - - test("Support IsVariantNull variant expressions with collation") { - case class IsVariantNullTestCase( - input: String, - collationName: String, - result: Boolean - ) - - val testCases = Seq( - IsVariantNullTestCase("'null'", "UTF8_BINARY", result = true), - IsVariantNullTestCase("'\"null\"'", "UTF8_LCASE", result = false), - IsVariantNullTestCase("'13'", "UNICODE", result = false), - IsVariantNullTestCase("null", "UNICODE_CI", result = false) - ) - - // Supported collations - testCases.foreach(t => { - val query = - s""" - |SELECT is_variant_null(parse_json(${t.input})) - |""".stripMargin - // Result & data type - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { - val testQuery = sql(query) - checkAnswer(testQuery, Row(t.result)) - } - }) - } - test("Handle invalid JSON for VariantGet variant expression with collation") { // variant_get should throw an exception if the cast fails val json = "[1, \"Spark\"]" From ca7ec5a0ceb80bf35b61df318b84b00aa3cd532e Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 20 Nov 2024 20:09:09 +0100 Subject: [PATCH 32/54] formatting --- .../org/apache/spark/sql/CollationSQLExpressionsSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 8fd9de490334c..cf494fcd87451 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import scala.collection.immutable.Seq + import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions._ @@ -1508,7 +1510,7 @@ class CollationSQLExpressionsSuite }) } - test("Support VariantGet & TryVariantGet variant expressions with collation") { + test("Support VariantGet & TryVariantGet variant expressions with collation") { case class VariantGetTestCase( input: String, path: String, From 7d6dccc1a734a8cce7d29134ce558e9b1a202fa8 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 21 Nov 2024 18:18:40 +0100 Subject: [PATCH 33/54] move resolve to spark core --- .../sql/catalyst/analysis/Analyzer.scala | 5 -- .../analysis/CollationTypeCoercion.scala | 13 +++-- .../analysis/ResolveInlineTables.scala | 13 ++++- .../sql/catalyst/parser/AstBuilder.scala | 29 ++-------- .../catalyst/plans/logical/v2Commands.scala | 2 - .../apache/spark/sql/util/SchemaUtils.scala | 8 +++ .../analysis/ResolveDefaultStringType.scala | 58 ++++++++++++++----- .../spark/sql/execution/datasources/ddl.scala | 4 +- .../internal/BaseSessionStateBuilder.scala | 3 +- .../sql/hive/HiveSessionStateBuilder.scala | 3 +- 10 files changed, 
84 insertions(+), 54 deletions(-) rename sql/{catalyst => core}/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala (75%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2d9ac0a34708f..bed7bea61597f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -316,11 +316,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveProcedures :: BindProcedures :: ResolveTableSpec :: - // Due to how tree transformations work and StringType object being equal to - // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice - // to ensure the correct results for occurrences of default string type. - new ResolveDefaultStringType(replaceWithTempType = true) :: - new ResolveDefaultStringType(replaceWithTempType = false) :: ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index 4c7933848dcb5..240cd557de142 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} import org.apache.spark.sql.util.SchemaUtils /** @@ -178,9 +178,14 @@ object CollationTypeCoercion { * if expression has StringType in the first place. 
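 * For example, with the change below a string Literal cast to STRING COLLATE UNICODE is rewritten in place (lit.copy(dataType = dt)) rather than wrapped in an additional Cast node, and an existing Cast simply has its target dataType replaced.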
*/ def castStringType(expr: Expression, st: StringType): Expression = { - castStringType(expr.dataType, st) - .map(dt => Cast(expr, dt)) - .getOrElse(expr) + castStringType(expr.dataType, st) match { + case Some(dt) => expr match { + case lit: Literal => lit.copy(dataType = dt) + case cast: Cast => cast.copy(dataType = dt) + case _ => Cast(expr, dt) + } + case _ => expr + } } private def castStringType(inType: DataType, castType: StringType): Option[DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 62f3997491c07..a1498a5df4acb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable -import org.apache.spark.sql.catalyst.expressions.EvalHelper +import org.apache.spark.sql.catalyst.expressions.{EvalHelper, Expression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess +import org.apache.spark.sql.util.SchemaUtils /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[ResolvedInlineTable]]. @@ -29,8 +30,16 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper { override def apply(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) { - case table: UnresolvedInlineTable if table.expressionsResolved => + case table: UnresolvedInlineTable if canResolveInlineTable(table) => EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table) } } + + def canResolveInlineTable(table: UnresolvedInlineTable): Boolean = { + table.expressionsResolved && !table.expressions.exists(hasUnresolvedStringType) + } + + private def hasUnresolvedStringType(expression: Expression): Boolean = { + expression.exists(e => SchemaUtils.hasDefaultStringType(e.dataType)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 47660c7ec7dbc..cef0e3fc5261b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2126,9 +2126,8 @@ class AstBuilder extends DataTypeAstBuilder Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } - val rowSeq = rows.toSeq - val unresolvedTable = UnresolvedInlineTable(aliases, rowSeq) - val table = if (canEagerlyEvaluateInlineTable(rowSeq)) { + val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq) + val table = if (canEagerlyEvaluateInlineTable(unresolvedTable)) { EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) } else { unresolvedTable @@ -2137,27 +2136,11 @@ class AstBuilder extends DataTypeAstBuilder } /** - * Determines if the inline table can be eagerly evaluated. Eager evaluation is not allowed - * if the session-level collation is set and there are string literals present in the expressions, - * because the result may depend on the collation of the input. + * Determines if the inline table can be eagerly evaluated. 
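+ * Eager evaluation is skipped while any expression still carries the unresolved default string type, because materializing such values early could bake in the wrong session collation; the check is shared with ResolveInlineTables.canResolveInlineTable.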
*/ - private def canEagerlyEvaluateInlineTable(rows: Seq[Seq[Expression]]): Boolean = { - val configSet = conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED) - val sessionCollationUnchanged = conf.defaultStringType == StringType - - configSet && - (sessionCollationUnchanged || !rows.exists(_.exists(containsStringLiteral))) - } - - private def containsStringLiteral(expression: Expression): Boolean = { - def inner(expr: Expression): Boolean = expr match { - case Literal(_, dataType) => - dataType.existsRecursively(_.isInstanceOf[StringType]) - case _ => - expr.children.exists(inner) - } - - expression.resolved && inner(expression) + private def canEagerlyEvaluateInlineTable(table: UnresolvedInlineTable): Boolean = { + conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED) && + ResolveInlineTables.canResolveInlineTable(table) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index bffefbf2f4fea..b465e0e11612f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -459,8 +459,6 @@ trait V2CreateTableAsSelectPlan newQuery: LogicalPlan): V2CreateTableAsSelectPlan } -trait V1CreateTablePlan extends LogicalPlan {} - /** A trait used for logical plan nodes that create or replace V2 table definitions. */ trait V2CreateTablePlan extends LogicalPlan { def name: LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 1e0bac331dc75..4c2e0a2bee9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -304,6 +304,14 @@ private[spark] object SchemaUtils { } } + /** + * Checks if a given data type has a default string type (StringType object), + * which should be resolved to a specific string type in the analyzer. + */ + def hasDefaultStringType(dt: DataType): Boolean = { + dt.existsRecursively(_.eq(StringType)) + } + /** * Replaces any collated string type with non collated StringType * recursively in the given data type. 
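Aside: the eq-based check above works because the default StringType singleton is ==-equal to an explicit StringType("UTF8_BINARY") (both carry collationId 0, as the comments in these patches note), so value equality cannot flag unresolved occurrences. A minimal standalone sketch of the distinction (object name hypothetical; it assumes the StringType(collationName) factory the test suites above already use):

  import org.apache.spark.sql.types.StringType

  object DefaultStringTypeDemo extends App {
    val explicit = StringType("UTF8_BINARY")
    // Value equality holds: both sides have collationId 0.
    assert(StringType == explicit)
    // Only reference equality singles out the unresolved default singleton,
    // which is what existsRecursively(_.eq(StringType)) detects.
    assert(!(StringType eq explicit))
  }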
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala similarity index 75% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala index d8230292f91f5..ffe08d580e102 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala @@ -18,15 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StringType} - - -case class TemporaryStringType( - override val collationId: Int = SQLConf.get.defaultStringType.collationId) - extends StringType(collationId) {} +import org.apache.spark.sql.types.{DataType, StringType, StructType} /** * Resolves default string types in DDL commands. For DML commands, the default string type is @@ -40,13 +36,21 @@ case class TemporaryStringType( */ class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - if (isDefaultSessionCollationUsed) { - plan - } else if (isDDLCommand(plan)) { + val newPlan = if (isDDLCommand(plan)) { transformDDL(plan) } else { + val newType = stringTypeForDMLCommand transformPlan(plan, SQLConf.get.defaultStringType) } + + if (!replaceWithTempType || newPlan.fastEquals(plan)) { + newPlan + } else { + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string type. + ResolveDefaultStringTypeWithoutTempType.apply(newPlan) + } } private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType @@ -56,20 +60,35 @@ class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[Logica * UTF8_BINARY). */ private def stringTypeForDDLCommand(table: LogicalPlan): StringType = - StringType(0) + StringType("UTF8_BINARY") - private def isDDLCommand(plan: LogicalPlan): Boolean = plan match { - case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true + /** Returns the default string type that should be used in DML commands. 
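+ * For example, under a session collation of UNICODE_CI (SqlApiConf.DEFAULT_COLLATION), a bare STRING in a DML query resolves to STRING COLLATE UNICODE_CI, whereas DDL commands keep resolving it to UTF8_BINARY.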
*/ + private def stringTypeForDMLCommand: StringType = + if (isDefaultSessionCollationUsed) { + StringType("UTF8_BINARY") + } else { + SQLConf.get.defaultStringType + } + + private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists { + case _: CreateTable | _: AddColumns | _: ReplaceColumns | _: AlterColumn => true case _ => isCreateOrAlterPlan(plan) } private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { - case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1CreateTablePlan => true + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true case _ => false } private def transformDDL(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { + case createTable: CreateTable => + val newType = stringTypeForDDLCommand(createTable) + val newSchema = replaceDefaultStringType(createTable.tableDesc.schema, newType) + .asInstanceOf[StructType] + val withNewSchema = createTable.copy(createTable.tableDesc.copy(schema = newSchema)) + transformPlan(withNewSchema, newType) + case p if isCreateOrAlterPlan(p) => val newType = stringTypeForDDLCommand(p) transformPlan(p, newType) @@ -152,3 +171,14 @@ class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[Logica } } } + +case object ResolveDefaultStringType + extends ResolveDefaultStringType(replaceWithTempType = true) {} + +case object ResolveDefaultStringTypeWithoutTempType + extends ResolveDefaultStringType(replaceWithTempType = false) {} + +case class TemporaryStringType(override val collationId: Int = + SQLConf.get.defaultStringType.collationId) + extends StringType(collationId) {} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index eb9d5813cff7b..d9367d92d462e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ case class CreateTable( tableDesc: CatalogTable, mode: SaveMode, - query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan { + query: Option[LogicalPlan]) extends LogicalPlan { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 59a873ef982fe..a438c63c19278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.artifact.ArtifactManager 
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringType, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -204,6 +204,7 @@ abstract class BaseSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: + ResolveDefaultStringType +: new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index dbeb8607facc2..9e48a9c3976f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringType, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -90,6 +90,7 @@ class HiveSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: + ResolveDefaultStringType +: new ResolveSessionCatalog(catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: From 0b0364f70734ed6c6cc4b801c039497e8d1ffbcd Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 21 Nov 2024 20:32:30 +0100 Subject: [PATCH 34/54] use -1 id for temporary string type --- .../analysis/CollationTypeCoercion.scala | 2 + .../expressions/stringExpressions.scala | 2 +- .../spark/sql/catalyst/rules/Rule.scala | 2 + .../sql/catalyst/rules/RuleExecutor.scala | 7 +++- ...esolveDefaultStringTypeWithTempType.scala} | 42 ++++++++++--------- .../internal/BaseSessionStateBuilder.scala | 10 +++-- .../sql/hive/HiveSessionStateBuilder.scala | 8 +++- 7 files changed, 47 insertions(+), 26 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/{ResolveDefaultStringType.scala => ResolveDefaultStringTypeWithTempType.scala} (86%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index 240cd557de142..9cbe164e90c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -218,6 +218,8 @@ object CollationTypeCoercion { private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = { if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { return None + } else if (expressions.exists(e => SchemaUtils.hasDefaultStringType(e.dataType))) { + return None } val collationContextWinner = expressions.foldLeft(findCollationContext(expressions.head)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c97920619ba4d..2ea53350fea36 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1863,7 +1863,7 @@ trait PadExpressionBuilderBase extends ExpressionBuilder { BinaryPad(funcName, expressions(0), expressions(1), Literal(Array[Byte](0))) } else { createStringPad(expressions(0), - expressions(1), Literal.create(" ", SQLConf.get.defaultStringType)) + expressions(1), Literal(" ")) } } else if (numArgs == 3) { if (expressions(0).dataType == BinaryType && expressions(2).dataType == BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index e263f827d2847..b14cf0476b09d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -32,5 +32,7 @@ abstract class Rule[TreeType <: TreeNode[_]] extends SQLConfHelper with Logging if (className endsWith "$") className.dropRight(1) else className } + def requiresRestart: Boolean = false + def apply(plan: TreeType): TreeType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 76d36fab2096a..4999dbf575bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.rules +import scala.collection.mutable.ListBuffer + import org.apache.spark.SparkException import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ @@ -236,6 +238,8 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { // Run until fix point or the max number of iterations as specified in the strategy. while (continue) { + val effectiveRules = new ListBuffer[Rule[TreeType]]() + curPlan = batch.rules.foldLeft(curPlan) { case (plan, rule) => val startTime = System.nanoTime() @@ -246,6 +250,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (effective) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) + effectiveRules.addOne(rule) planChangeLogger.logRule(rule.ruleName, plan, result) // Run the plan changes validation after each rule. 
if (fullValidation || lightweightValidation) { @@ -303,7 +308,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { continue = false } - if (curPlan.fastEquals(lastPlan)) { + if (curPlan.fastEquals(lastPlan) && !effectiveRules.exists(_.requiresRestart)) { logTrace( s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.") continue = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala index ffe08d580e102..057892c424c30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala @@ -34,25 +34,18 @@ import org.apache.spark.sql.types.{DataType, StringType, StructType} * [[TemporaryStringType]] object in cases where the old type and new are equal and thus would * not change the plan after transformation. */ -class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { +class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val newPlan = if (isDDLCommand(plan)) { + if (isDDLCommand(plan)) { transformDDL(plan) } else { val newType = stringTypeForDMLCommand - transformPlan(plan, SQLConf.get.defaultStringType) - } - - if (!replaceWithTempType || newPlan.fastEquals(plan)) { - newPlan - } else { - // Due to how tree transformations work and StringType object being equal to - // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice - // to ensure the correct results for occurrences of default string type. 
- ResolveDefaultStringTypeWithoutTempType.apply(newPlan) + transformPlan(plan, newType) } } + override def requiresRestart: Boolean = true + private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType /** @@ -154,8 +147,20 @@ class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[Logica private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { dataType.transformRecursively { - case currentType if isDefaultStringType(currentType) => - if (replaceWithTempType && currentType == newType) TemporaryStringType() else newType + case currentType: StringType if isDefaultStringType(currentType) => + if (replaceWithTempType && currentType == newType) { + getTemporaryStringType(currentType) + } else { + newType + } + } + } + + private def getTemporaryStringType(forType: StringType): StringType = { + if (forType.collationId == 0) { + TemporaryStringType(1) + } else { + TemporaryStringType(0) } } @@ -172,13 +177,12 @@ class ResolveDefaultStringType(replaceWithTempType: Boolean) extends Rule[Logica } } -case object ResolveDefaultStringType - extends ResolveDefaultStringType(replaceWithTempType = true) {} +case object ResolveDefaultStringTypeWithTempType + extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = true) {} case object ResolveDefaultStringTypeWithoutTempType - extends ResolveDefaultStringType(replaceWithTempType = false) {} + extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = false) {} -case class TemporaryStringType(override val collationId: Int = - SQLConf.get.defaultStringType.collationId) +case class TemporaryStringType(override val collationId: Int) extends StringType(collationId) {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index a438c63c19278..7881e6b6e6cc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql._ import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringType, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -204,7 +204,11 @@ abstract class BaseSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - ResolveDefaultStringType +: + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string 
type. + ResolveDefaultStringTypeWithTempType +: + ResolveDefaultStringTypeWithoutTempType +: new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 9e48a9c3976f9..95114a4d494b3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringType, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -90,7 +90,11 @@ class HiveSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - ResolveDefaultStringType +: + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string type. 
+ ResolveDefaultStringTypeWithTempType +: + ResolveDefaultStringTypeWithoutTempType +: new ResolveSessionCatalog(catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: From 052ff47339152abf5b6cc94e35e1605959ea132c Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 21 Nov 2024 22:01:21 +0100 Subject: [PATCH 35/54] move the rule back to the catalyst --- .../sql/catalyst/analysis/Analyzer.scala | 1 + .../analysis/CollationTypeCoercion.scala | 3 ++ .../analysis/ResolveDefaultStringTypes.scala} | 39 +++++++++---------- .../catalyst/plans/logical/v2Commands.scala | 3 ++ .../spark/sql/catalyst/rules/Rule.scala | 2 - .../sql/catalyst/rules/RuleExecutor.scala | 7 +--- .../spark/sql/execution/datasources/ddl.scala | 4 +- .../internal/BaseSessionStateBuilder.scala | 9 +---- .../collation/DefaultCollationTestSuite.scala | 21 ++++++++++ .../sql/hive/HiveSessionStateBuilder.scala | 7 +--- 10 files changed, 53 insertions(+), 43 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala => catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala} (84%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bed7bea61597f..3e4db4c39661e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -319,6 +319,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: + ResolveDefaultStringTypes :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index 9cbe164e90c38..edbbe4835730a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -217,8 +217,11 @@ object CollationTypeCoercion { */ private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = { if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { + // if there are no collations return None return None } else if (expressions.exists(e => SchemaUtils.hasDefaultStringType(e.dataType))) { + // if there are default string types, return None as they need to be resolved + // first in [[ResolveDefaultStringTypes]] return None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 057892c424c30..1a9aaa2592f63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -18,11 
+18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StringType, StructType} +import org.apache.spark.sql.types.{DataType, StringType} /** * Resolves default string types in DDL commands. For DML commands, the default string type is @@ -34,17 +33,24 @@ import org.apache.spark.sql.types.{DataType, StringType, StructType} * [[TemporaryStringType]] object in cases where the old type and new are equal and thus would * not change the plan after transformation. */ -class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { +class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - if (isDDLCommand(plan)) { + val newPlan = if (isDDLCommand(plan)) { transformDDL(plan) } else { val newType = stringTypeForDMLCommand transformPlan(plan, newType) } - } - override def requiresRestart: Boolean = true + if (!replaceWithTempType || newPlan.fastEquals(plan)) { + newPlan + } else { + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string type. 
+ ResolveDefaultStringTypesWithoutTempType.apply(newPlan) + } + } private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType @@ -64,24 +70,17 @@ class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends } private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists { - case _: CreateTable | _: AddColumns | _: ReplaceColumns | _: AlterColumn => true + case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true case _ => isCreateOrAlterPlan(plan) } private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { - case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1CreateTablePlan => true case _ => false } private def transformDDL(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { - case createTable: CreateTable => - val newType = stringTypeForDDLCommand(createTable) - val newSchema = replaceDefaultStringType(createTable.tableDesc.schema, newType) - .asInstanceOf[StructType] - val withNewSchema = createTable.copy(createTable.tableDesc.copy(schema = newSchema)) - transformPlan(withNewSchema, newType) - case p if isCreateOrAlterPlan(p) => val newType = stringTypeForDDLCommand(p) transformPlan(p, newType) @@ -177,11 +176,11 @@ class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends } } -case object ResolveDefaultStringTypeWithTempType - extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = true) {} +case object ResolveDefaultStringTypes + extends ResolveDefaultStringTypes(replaceWithTempType = true) {} -case object ResolveDefaultStringTypeWithoutTempType - extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = false) {} +case object ResolveDefaultStringTypesWithoutTempType + extends ResolveDefaultStringTypes(replaceWithTempType = false) {} case class TemporaryStringType(override val collationId: Int) extends StringType(collationId) {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index b465e0e11612f..26532467c94da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -459,6 +459,9 @@ trait V2CreateTableAsSelectPlan newQuery: LogicalPlan): V2CreateTableAsSelectPlan } +/** A trait used for logical plan nodes that create or replace V1 table definitions. */ +trait V1CreateTablePlan extends LogicalPlan {} + /** A trait used for logical plan nodes that create or replace V2 table definitions. 
*/ trait V2CreateTablePlan extends LogicalPlan { def name: LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index b14cf0476b09d..e263f827d2847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -32,7 +32,5 @@ abstract class Rule[TreeType <: TreeNode[_]] extends SQLConfHelper with Logging if (className endsWith "$") className.dropRight(1) else className } - def requiresRestart: Boolean = false - def apply(plan: TreeType): TreeType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 4999dbf575bc1..76d36fab2096a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.rules -import scala.collection.mutable.ListBuffer - import org.apache.spark.SparkException import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ @@ -238,8 +236,6 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { // Run until fix point or the max number of iterations as specified in the strategy. while (continue) { - val effectiveRules = new ListBuffer[Rule[TreeType]]() - curPlan = batch.rules.foldLeft(curPlan) { case (plan, rule) => val startTime = System.nanoTime() @@ -250,7 +246,6 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (effective) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) - effectiveRules.addOne(rule) planChangeLogger.logRule(rule.ruleName, plan, result) // Run the plan changes validation after each rule. 
if (fullValidation || lightweightValidation) { @@ -308,7 +303,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { continue = false } - if (curPlan.fastEquals(lastPlan) && !effectiveRules.exists(_.requiresRestart)) { + if (curPlan.fastEquals(lastPlan)) { logTrace( s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.") continue = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index d9367d92d462e..eb9d5813cff7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ case class CreateTable( tableDesc: CatalogTable, mode: SaveMode, - query: Option[LogicalPlan]) extends LogicalPlan { + query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 7881e6b6e6cc9..59a873ef982fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable -import org.apache.spark.sql._ +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -204,11 +204,6 @@ abstract class BaseSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - // Due to how tree transformations work and StringType object being equal to - // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice - // to ensure the correct results for occurrences of default string type. 
- ResolveDefaultStringTypeWithTempType +: - ResolveDefaultStringTypeWithoutTempType +: new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 21c3104690b16..feaaf85bfb512 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -174,6 +174,27 @@ abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSessi } } + test("ctas with inline table") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT c1, c1 = 'A' as c2 FROM VALUES ('a'), ('A') AS vals(c1) + |""".stripMargin) + + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(1))) + } + + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT c1 as c2 FROM VALUES ('a' = 'A') AS vals(c1) + |""".stripMargin) + + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0))) + } + } + test("add column") { withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 95114a4d494b3..dbeb8607facc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -90,11 +90,6 @@ class HiveSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - // Due to how tree transformations work and StringType object being equal to - // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice - // to ensure the correct results for occurrences of default string type. 
- ResolveDefaultStringTypeWithTempType +: - ResolveDefaultStringTypeWithoutTempType +: new ResolveSessionCatalog(catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: From 3e1522648e5351a6245fb08aa81f0fca363c0a53 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Mon, 25 Nov 2024 17:48:59 +0100 Subject: [PATCH 36/54] Revert "move the rule back to the catalyst" This reverts commit 052ff47339152abf5b6cc94e35e1605959ea132c. --- .../sql/catalyst/analysis/Analyzer.scala | 1 - .../analysis/CollationTypeCoercion.scala | 3 -- .../catalyst/plans/logical/v2Commands.scala | 3 -- .../spark/sql/catalyst/rules/Rule.scala | 2 + .../sql/catalyst/rules/RuleExecutor.scala | 7 +++- ...esolveDefaultStringTypeWithTempType.scala} | 39 ++++++++++--------- .../spark/sql/execution/datasources/ddl.scala | 4 +- .../internal/BaseSessionStateBuilder.scala | 9 ++++- .../collation/DefaultCollationTestSuite.scala | 21 ---------- .../sql/hive/HiveSessionStateBuilder.scala | 7 +++- 10 files changed, 43 insertions(+), 53 deletions(-) rename sql/{catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala => core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala} (84%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3e4db4c39661e..bed7bea61597f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -319,7 +319,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: - ResolveDefaultStringTypes :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index edbbe4835730a..9cbe164e90c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -217,11 +217,8 @@ object CollationTypeCoercion { */ private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = { if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { - // if there are no collations return None return None } else if (expressions.exists(e => SchemaUtils.hasDefaultStringType(e.dataType))) { - // if there are default string types, return None as they need to be resolved - // first in [[ResolveDefaultStringTypes]] return None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 26532467c94da..b465e0e11612f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -459,9 +459,6 @@ trait V2CreateTableAsSelectPlan newQuery: LogicalPlan): V2CreateTableAsSelectPlan } -/** A trait used for logical plan nodes that create or replace V1 table definitions. 
*/ -trait V1CreateTablePlan extends LogicalPlan {} - /** A trait used for logical plan nodes that create or replace V2 table definitions. */ trait V2CreateTablePlan extends LogicalPlan { def name: LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index e263f827d2847..b14cf0476b09d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -32,5 +32,7 @@ abstract class Rule[TreeType <: TreeNode[_]] extends SQLConfHelper with Logging if (className endsWith "$") className.dropRight(1) else className } + def requiresRestart: Boolean = false + def apply(plan: TreeType): TreeType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 76d36fab2096a..4999dbf575bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.rules +import scala.collection.mutable.ListBuffer + import org.apache.spark.SparkException import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ @@ -236,6 +238,8 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { // Run until fix point or the max number of iterations as specified in the strategy. while (continue) { + val effectiveRules = new ListBuffer[Rule[TreeType]]() + curPlan = batch.rules.foldLeft(curPlan) { case (plan, rule) => val startTime = System.nanoTime() @@ -246,6 +250,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (effective) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) + effectiveRules.addOne(rule) planChangeLogger.logRule(rule.ruleName, plan, result) // Run the plan changes validation after each rule. 
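Illustration (editorial sketch, not part of the diff): the hunks above restore `requiresRestart` on `Rule` and the per-iteration `effectiveRules` buffer in `RuleExecutor`. The subtlety is that a batch can end an iteration on a plan that fastEquals the one it started with even though some rule did fire in between (for instance, one rule swaps in a marker type and a later rule swaps it back out); such a rule must be able to veto the fixed-point exit. A minimal stand-alone model of that loop, with hypothetical names (`MiniRule`, `runBatch`) and plain `==` standing in for fastEquals:

    import scala.collection.mutable.ListBuffer

    trait MiniRule[T] {
      def requiresRestart: Boolean = false // mirrors Rule.requiresRestart above
      def apply(plan: T): T
    }

    def runBatch[T](rules: Seq[MiniRule[T]], start: T, maxIters: Int = 100): T = {
      var plan = start
      var iter = 0
      var continue = true
      while (continue && iter < maxIters) {
        val lastPlan = plan
        val effective = ListBuffer.empty[MiniRule[T]]
        rules.foreach { rule =>
          val result = rule(plan)
          if (result != plan) effective += rule // the rule changed the plan at its step
          plan = result
        }
        // An unchanged plan only counts as a fixed point if no effective rule
        // demands another pass over the whole batch.
        if (plan == lastPlan && !effective.exists(_.requiresRestart)) continue = false
        iter += 1
      }
      plan
    }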
if (fullValidation || lightweightValidation) { @@ -303,7 +308,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { continue = false } - if (curPlan.fastEquals(lastPlan)) { + if (curPlan.fastEquals(lastPlan) && !effectiveRules.exists(_.requiresRestart)) { logTrace( s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.") continue = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala similarity index 84% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala index 1a9aaa2592f63..057892c424c30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{DataType, StringType, StructType} /** * Resolves default string types in DDL commands. For DML commands, the default string type is @@ -33,25 +34,18 @@ import org.apache.spark.sql.types.{DataType, StringType} * [[TemporaryStringType]] object in cases where the old type and new are equal and thus would * not change the plan after transformation. */ -class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { +class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val newPlan = if (isDDLCommand(plan)) { + if (isDDLCommand(plan)) { transformDDL(plan) } else { val newType = stringTypeForDMLCommand transformPlan(plan, newType) } - - if (!replaceWithTempType || newPlan.fastEquals(plan)) { - newPlan - } else { - // Due to how tree transformations work and StringType object being equal to - // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice - // to ensure the correct results for occurrences of default string type. 
- ResolveDefaultStringTypesWithoutTempType.apply(newPlan) - } } + override def requiresRestart: Boolean = true + private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType /** @@ -70,17 +64,24 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic } private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists { - case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true + case _: CreateTable | _: AddColumns | _: ReplaceColumns | _: AlterColumn => true case _ => isCreateOrAlterPlan(plan) } private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { - case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs | _: V1CreateTablePlan => true + case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true case _ => false } private def transformDDL(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { + case createTable: CreateTable => + val newType = stringTypeForDDLCommand(createTable) + val newSchema = replaceDefaultStringType(createTable.tableDesc.schema, newType) + .asInstanceOf[StructType] + val withNewSchema = createTable.copy(createTable.tableDesc.copy(schema = newSchema)) + transformPlan(withNewSchema, newType) + case p if isCreateOrAlterPlan(p) => val newType = stringTypeForDDLCommand(p) transformPlan(p, newType) @@ -176,11 +177,11 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic } } -case object ResolveDefaultStringTypes - extends ResolveDefaultStringTypes(replaceWithTempType = true) {} +case object ResolveDefaultStringTypeWithTempType + extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = true) {} -case object ResolveDefaultStringTypesWithoutTempType - extends ResolveDefaultStringTypes(replaceWithTempType = false) {} +case object ResolveDefaultStringTypeWithoutTempType + extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = false) {} case class TemporaryStringType(override val collationId: Int) extends StringType(collationId) {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index eb9d5813cff7b..d9367d92d462e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ case class CreateTable( tableDesc: CatalogTable, mode: SaveMode, - query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan { + query: Option[LogicalPlan]) extends LogicalPlan { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 59a873ef982fe..7881e6b6e6cc9 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql._ import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -204,6 +204,11 @@ abstract class BaseSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string type. + ResolveDefaultStringTypeWithTempType +: + ResolveDefaultStringTypeWithoutTempType +: new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index feaaf85bfb512..21c3104690b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -174,27 +174,6 @@ abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSessi } } - test("ctas with inline table") { - withSessionCollationAndTable("UTF8_LCASE", testTable) { - sql(s""" - |CREATE TABLE $testTable USING $dataSource AS - |SELECT c1, c1 = 'A' as c2 FROM VALUES ('a'), ('A') AS vals(c1) - |""".stripMargin) - - assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") - checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(1))) - } - - withSessionCollationAndTable("UTF8_LCASE", testTable) { - sql(s""" - |CREATE TABLE $testTable USING $dataSource AS - |SELECT c1 as c2 FROM VALUES ('a' = 'A') AS vals(c1) - |""".stripMargin) - - checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0))) - } - } - test("add column") { withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index dbeb8607facc2..95114a4d494b3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} 
import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -90,6 +90,11 @@ class HiveSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string type. + ResolveDefaultStringTypeWithTempType +: + ResolveDefaultStringTypeWithoutTempType +: new ResolveSessionCatalog(catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: From 86dc81b56f1b3b9f6d19adf197a36e2282c66d26 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Mon, 25 Nov 2024 21:03:12 +0100 Subject: [PATCH 37/54] initial --- .../analysis/CollationTypeCoercion.scala | 5 ++- .../analysis/ResolveInlineTables.scala | 8 ++-- .../sql/catalyst/expressions/ExprUtils.scala | 25 ++++++++++- .../spark/sql/catalyst/rules/Rule.scala | 2 - .../sql/catalyst/rules/RuleExecutor.scala | 10 ++++- ....scala => ResolveDefaultStringTypes.scala} | 44 +++++++++++++------ .../internal/BaseSessionStateBuilder.scala | 8 +--- .../collation/DefaultCollationTestSuite.scala | 42 ++++++++++++++++-- .../sql/hive/HiveSessionStateBuilder.scala | 8 +--- 9 files changed, 111 insertions(+), 41 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/{ResolveDefaultStringTypeWithTempType.scala => ResolveDefaultStringTypes.scala} (84%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index 9cbe164e90c38..03701a5e5f0aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -217,8 +217,11 @@ object CollationTypeCoercion { */ private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = { if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { + // if there are no collated types we don't need to do anything return None - } else if (expressions.exists(e => SchemaUtils.hasDefaultStringType(e.dataType))) { + } else if (expressions.exists(ExprUtils.hasExpressionWithDefaultStringType)) { + // if there are default string types, return None as they need to be resolved + // first in [[ResolveDefaultStringTypes]] return None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index a1498a5df4acb..0dee161013d58 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -18,11 +18,10 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable
-import org.apache.spark.sql.catalyst.expressions.{EvalHelper, Expression}
+import org.apache.spark.sql.catalyst.expressions.{EvalHelper, Expression, ExprUtils}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.AlwaysProcess
-import org.apache.spark.sql.util.SchemaUtils
 /**
  * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[ResolvedInlineTable]].
@@ -39,7 +38,6 @@ object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper {
     table.expressionsResolved && !table.expressions.exists(hasUnresolvedStringType)
   }
-  private def hasUnresolvedStringType(expression: Expression): Boolean = {
-    expression.exists(e => SchemaUtils.hasDefaultStringType(e.dataType))
-  }
+  private def hasUnresolvedStringType(expression: Expression): Boolean =
+    ExprUtils.hasExpressionWithDefaultStringType(expression)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index e65a0200b064f..0102e7df5f9e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnDefinition}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
@@ -223,4 +223,27 @@ object ExprUtils extends EvalHelper with QueryErrorsBase {
     a.groupingExpressions.foreach(checkValidGroupingExprs)
     a.aggregateExpressions.foreach(checkValidAggregateExpression)
   }
+
+  /**
+   * Returns whether any expression in the expression tree of the input
+   * has a default string type which should be resolved first.
+   */
+  def hasExpressionWithDefaultStringType(expression: Expression): Boolean = {
+    expression.exists {
+      case e @ (_: Literal | _: ColumnDefinition | _: Cast) =>
+        hasDefaultStringType(e.dataType)
+      case _ =>
+        false
+    }
+  }
+
+  /**
+   * Returns whether the input data type has a default string type which should be resolved first.
+ */ + def hasDefaultStringType(dataType: DataType): Boolean = + dataType.existsRecursively { + case st: StringType => + st.eq(StringType) + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index b14cf0476b09d..e263f827d2847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -32,7 +32,5 @@ abstract class Rule[TreeType <: TreeNode[_]] extends SQLConfHelper with Logging if (className endsWith "$") className.dropRight(1) else className } - def requiresRestart: Boolean = false - def apply(plan: TreeType): TreeType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 4999dbf575bc1..483d8cab3270e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.QueryPlanningTracker -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.errors.QueryExecutionErrors @@ -32,6 +32,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils object RuleExecutor { + val ONE_MORE_ITER: TreeNodeTag[Unit] = TreeNodeTag[Unit]("one_more_iter") + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules. 
*/ @@ -308,7 +310,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { continue = false } - if (curPlan.fastEquals(lastPlan) && !effectiveRules.exists(_.requiresRestart)) { + if (reachedFixedPoint(lastPlan, curPlan)) { logTrace( s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.") continue = false @@ -322,4 +324,8 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { curPlan } + + private def reachedFixedPoint(oldPlan: TreeType, newPlan: TreeType): Boolean = { + oldPlan.fastEquals(newPlan) && newPlan.getTagValue(RuleExecutor.ONE_MORE_ITER).isEmpty + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 057892c424c30..8cc7cc5edd573 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypeWithTempType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.rules.RuleExecutor.ONE_MORE_ITER import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StringType, StructType} @@ -34,17 +35,33 @@ import org.apache.spark.sql.types.{DataType, StringType, StructType} * [[TemporaryStringType]] object in cases where the old type and new are equal and thus would * not change the plan after transformation. */ -class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { +class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - if (isDDLCommand(plan)) { + + val newPlan = if (isDDLCommand(plan)) { transformDDL(plan) } else { val newType = stringTypeForDMLCommand transformPlan(plan, newType) } - } - override def requiresRestart: Boolean = true + val finalPlan = if (!replaceWithTempType || newPlan.fastEquals(plan)) { + newPlan + } else { + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice + // to ensure the correct results for occurrences of default string type. 
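Illustration (editorial sketch, not part of the diff): the comment above rests on the fact that the `StringType` singleton is value-equal to an explicit `StringType("UTF8_BINARY")`, so swapping one for the other is invisible to equality-based change detection; only reference equality tells them apart, which is what `st.eq(StringType)` in `ExprUtils.hasDefaultStringType` relies on. A REPL-style sketch of the pitfall, using a simplified hypothetical stand-in (`MiniStringType` is not Spark's real class):

    // Equality by collation id only, like StringType's value equality.
    class MiniStringType(val collationId: Int) {
      override def equals(other: Any): Boolean = other match {
        case s: MiniStringType => s.collationId == collationId
        case _ => false
      }
      override def hashCode(): Int = collationId
    }
    // Plays the role of the StringType case object, i.e. the "default" type.
    object DefaultString extends MiniStringType(0)

    val explicitBinary = new MiniStringType(0) // UTF8_BINARY spelled out by the user

    assert(DefaultString == explicitBinary)   // value equality: indistinguishable
    assert(!DefaultString.eq(explicitBinary)) // reference equality: distinguishable
    // So a transform that rewrites the default type into an equal concrete type
    // yields a plan that fastEquals the input, and the change is dropped unless
    // the rule detours through TemporaryStringType or forces one more iteration.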
+ ResolveDefaultStringTypesWithoutTempType.apply(newPlan) + } + + if (finalPlan == plan && finalPlan == newPlan) { + finalPlan.unsetTagValue(ONE_MORE_ITER) + } else { + finalPlan.setTagValue(ONE_MORE_ITER, ()) + } + + finalPlan + } private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType @@ -74,29 +91,26 @@ class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends } private def transformDDL(plan: LogicalPlan): LogicalPlan = { + val newType = stringTypeForDDLCommand(plan) + plan resolveOperators { case createTable: CreateTable => - val newType = stringTypeForDDLCommand(createTable) val newSchema = replaceDefaultStringType(createTable.tableDesc.schema, newType) .asInstanceOf[StructType] val withNewSchema = createTable.copy(createTable.tableDesc.copy(schema = newSchema)) transformPlan(withNewSchema, newType) case p if isCreateOrAlterPlan(p) => - val newType = stringTypeForDDLCommand(p) transformPlan(p, newType) case addCols: AddColumns => - val newType = stringTypeForDDLCommand(addCols.table) addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, newType)) case replaceCols: ReplaceColumns => - val newType = stringTypeForDDLCommand(replaceCols.table) replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType)) case alter: AlterColumn if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) => - val newType = stringTypeForDDLCommand(alter.table) alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType))) } } @@ -177,12 +191,14 @@ class ResolveDefaultStringTypeWithTempType(replaceWithTempType: Boolean) extends } } -case object ResolveDefaultStringTypeWithTempType - extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = true) {} +case object ResolveDefaultStringTypes + extends ResolveDefaultStringTypes(replaceWithTempType = true) {} -case object ResolveDefaultStringTypeWithoutTempType - extends ResolveDefaultStringTypeWithTempType(replaceWithTempType = false) {} +case object ResolveDefaultStringTypesWithoutTempType + extends ResolveDefaultStringTypes(replaceWithTempType = false) {} case class TemporaryStringType(override val collationId: Int) - extends StringType(collationId) {} + extends StringType(collationId) { + override def toString: String = s"TemporaryStringType($collationId)" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 7881e6b6e6cc9..8119c9d7b203a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable import org.apache.spark.sql._ import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypes, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import 
org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -204,11 +204,7 @@ abstract class BaseSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - // Due to how tree transformations work and StringType object being equal to - // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice - // to ensure the correct results for occurrences of default string type. - ResolveDefaultStringTypeWithTempType +: - ResolveDefaultStringTypeWithoutTempType +: + ResolveDefaultStringTypes +: new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index 21c3104690b16..9d934763ccb87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -187,6 +187,20 @@ abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSessi } } + test("inline table in CTAS") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable AS + |SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |AS inline_table(c1, c2); + |""".stripMargin) + + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0))) + } + } + test("subsequent analyzer iterations correctly resolve default string types") { // since concat coercion happens after resolving default types this test // makes sure that we are correctly resolving the default string types @@ -330,15 +344,35 @@ abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSessi } } + test("inline table in SELECT") { + withSessionCollation("UTF8_LCASE") { + val df = s""" + |SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |""".stripMargin + + checkAnswer(sql(df), Seq(Row("a", true))) + } + } + + test("inline table in insert") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 BOOLEAN) USING $dataSource") + + sql(s"INSERT INTO $testTable VALUES ('a', 'a' = 'A')") + checkAnswer(sql(s"SELECT * FROM $testTable"), Seq(Row("a", true))) + } + } + test("literals in insert inherit session level collation") { withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s"CREATE TABLE $testTable (c1 BOOLEAN) USING $dataSource") - sql(s"INSERT INTO $testTable VALUES ('a' = 'A')") - - checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c1"), Seq(Row(1))) + sql(s"INSERT INTO $testTable VALUES ('a' = 'A')") sql(s"INSERT INTO $testTable VALUES (array_contains(array('a'), 'A'))") - checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c1"), Seq(Row(2))) + sql(s"INSERT INTO $testTable VALUES (CONCAT(X'68656C6C6F', 'world') = 'HELLOWORLD')") + + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c1"), Seq(Row(3))) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 95114a4d494b3..e42d01a8d7264 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala 
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypeWithoutTempType, ResolveDefaultStringTypeWithTempType, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypes, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -90,11 +90,7 @@ class HiveSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - // Due to how tree transformations work and StringType object being equal to - // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice - // to ensure the correct results for occurrences of default string type. - ResolveDefaultStringTypeWithTempType +: - ResolveDefaultStringTypeWithoutTempType +: + ResolveDefaultStringTypes +: new ResolveSessionCatalog(catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: From c4dd6630f0db2e62f79ccfdcb4c22bf4873e8aed Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Mon, 25 Nov 2024 22:35:07 +0100 Subject: [PATCH 38/54] add better check for inline tables --- .../analysis/ResolveInlineTables.scala | 7 +- .../sql/catalyst/expressions/ExprUtils.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 35 ++++++++- .../analysis/ResolveDefaultStringTypes.scala | 8 +- .../analyzer-results/collations.sql.out | 78 +++++++++---------- .../analyzer-results/inline-table.sql.out | 4 +- .../postgreSQL/window_part4.sql.out | 2 +- .../sql-tests/results/collations.sql.out | 28 +++---- .../sql-tests/results/inline-table.sql.out | 4 +- .../results/postgreSQL/window_part4.sql.out | 2 +- .../spark/sql/CollationSQLRegexpSuite.scala | 3 +- .../collation/DefaultCollationTestSuite.scala | 14 ++++ .../spark/sql/hive/HiveInspectors.scala | 4 +- 13 files changed, 120 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 0dee161013d58..ff781f32a0ab1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -34,8 +34,11 @@ object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper { } } - def canResolveInlineTable(table: UnresolvedInlineTable): Boolean = { - table.expressionsResolved && !table.expressions.exists(hasUnresolvedStringType) + def canResolveInlineTable( + table: UnresolvedInlineTable, + checkStringTypes: Boolean = true): Boolean = { + table.expressionsResolved && + (!checkStringTypes || !table.expressions.exists(hasUnresolvedStringType)) } private def hasUnresolvedStringType(expression: Expression): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 0102e7df5f9e0..035fc32d27bfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -230,7 +230,7 @@ object ExprUtils extends EvalHelper with QueryErrorsBase { */ def hasExpressionWithDefaultStringType(expression: Expression): Boolean = { expression.exists { - case e @ (_: Literal | _: ColumnDefinition | _: Cast) => + case e @ (_: Literal | _: ColumnDefinition | _: Cast | _: NamedLambdaVariable) => hasDefaultStringType(e.dataType) case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cef0e3fc5261b..30c37b36e7eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2127,7 +2127,7 @@ class AstBuilder extends DataTypeAstBuilder } val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq) - val table = if (canEagerlyEvaluateInlineTable(unresolvedTable)) { + val table = if (canEagerlyEvaluateInlineTable(unresolvedTable, ctx)) { EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) } else { unresolvedTable @@ -2138,9 +2138,36 @@ class AstBuilder extends DataTypeAstBuilder /** * Determines if the inline table can be eagerly evaluated. */ - private def canEagerlyEvaluateInlineTable(table: UnresolvedInlineTable): Boolean = { - conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED) && - ResolveInlineTables.canResolveInlineTable(table) + private def canEagerlyEvaluateInlineTable( + table: UnresolvedInlineTable, + ctx: InlineTableContext): Boolean = { + if (!conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + return false + } + + val isSessionCollationSet = conf.defaultStringType != StringType + + // if either of these are true we also need to check string types to make sure + // that there are no unresolved string types in the expressions + val checkStringTypes = isSessionCollationSet || contextInsideCreate(ctx) + + ResolveInlineTables.canResolveInlineTable(table, checkStringTypes) + } + + private def contextInsideCreate(ctx: ParserRuleContext): Boolean = { + var currentContext: RuleContext = ctx + + while (currentContext != null) { + if (currentContext.isInstanceOf[CreateTableContext] || + currentContext.isInstanceOf[ReplaceTableContext] || + currentContext.isInstanceOf[CreateViewContext]) { + return true + } + + currentContext = currentContext.parent + } + + false } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 8cc7cc5edd573..8442b9b43b867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, NamedLambdaVariable} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, 
ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.rules.RuleExecutor.ONE_MORE_ITER +import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StringType, StructType} @@ -86,7 +87,7 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic } private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { - case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true + case _: V2CreateTablePlan | _: CreateView | _: CreateViewCommand | _: AlterViewAs => true case _ => false } @@ -141,6 +142,9 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic case Literal(value, dt) if hasDefaultStringType(dt) => Literal(value, replaceDefaultStringType(dt, newType)) + case lambdaVar: NamedLambdaVariable if hasDefaultStringType(lambdaVar.dataType) => + lambdaVar.copy(dataType = replaceDefaultStringType(lambdaVar.dataType, newType)) + case other => other } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 45ab1cdcff79e..0d5c414416d40 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -442,77 +442,77 @@ Project [array_except(array(collate(aaa, utf8_lcase)), array(collate(AAA, utf8_l -- !query select 'a' collate unicode < 'A' -- !query analysis -Project [(collate(a, unicode) < cast(A as string collate UNICODE)) AS (collate(a, unicode) < A)#x] +Project [(collate(a, unicode) < A) AS (collate(a, unicode) < 'A' collate UNICODE)#x] +- OneRowRelation -- !query select 'a' collate unicode_ci = 'A' -- !query analysis -Project [(collate(a, unicode_ci) = cast(A as string collate UNICODE_CI)) AS (collate(a, unicode_ci) = A)#x] +Project [(collate(a, unicode_ci) = A) AS (collate(a, unicode_ci) = 'A' collate UNICODE_CI)#x] +- OneRowRelation -- !query select 'a' collate unicode_ai = 'å' -- !query analysis -Project [(collate(a, unicode_ai) = cast(å as string collate UNICODE_AI)) AS (collate(a, unicode_ai) = å)#x] +Project [(collate(a, unicode_ai) = å) AS (collate(a, unicode_ai) = 'å' collate UNICODE_AI)#x] +- OneRowRelation -- !query select 'a' collate unicode_ci_ai = 'Å' -- !query analysis -Project [(collate(a, unicode_ci_ai) = cast(Å as string collate UNICODE_CI_AI)) AS (collate(a, unicode_ci_ai) = Å)#x] +Project [(collate(a, unicode_ci_ai) = Å) AS (collate(a, unicode_ci_ai) = 'Å' collate UNICODE_CI_AI)#x] +- OneRowRelation -- !query select 'a' collate en < 'A' -- !query analysis -Project [(collate(a, en) < cast(A as string collate en)) AS (collate(a, en) < A)#x] +Project [(collate(a, en) < A) AS (collate(a, en) < 'A' collate en)#x] +- OneRowRelation -- !query select 'a' collate en_ci = 'A' -- !query analysis -Project [(collate(a, en_ci) = cast(A as string collate en_CI)) AS (collate(a, en_ci) = A)#x] +Project [(collate(a, en_ci) = A) AS (collate(a, en_ci) = 'A' collate en_CI)#x] +- OneRowRelation -- !query select 'a' collate en_ai = 'å' -- !query analysis -Project [(collate(a, en_ai) = cast(å as string collate en_AI)) AS (collate(a, en_ai) = å)#x] +Project [(collate(a, en_ai) = å) AS (collate(a, en_ai) = 'å' collate en_AI)#x] +- 
OneRowRelation -- !query select 'a' collate en_ci_ai = 'Å' -- !query analysis -Project [(collate(a, en_ci_ai) = cast(Å as string collate en_CI_AI)) AS (collate(a, en_ci_ai) = Å)#x] +Project [(collate(a, en_ci_ai) = Å) AS (collate(a, en_ci_ai) = 'Å' collate en_CI_AI)#x] +- OneRowRelation -- !query select 'Kypper' collate sv < 'Köpfe' -- !query analysis -Project [(collate(Kypper, sv) < cast(Köpfe as string collate sv)) AS (collate(Kypper, sv) < Köpfe)#x] +Project [(collate(Kypper, sv) < Köpfe) AS (collate(Kypper, sv) < 'Köpfe' collate sv)#x] +- OneRowRelation -- !query select 'Kypper' collate de > 'Köpfe' -- !query analysis -Project [(collate(Kypper, de) > cast(Köpfe as string collate de)) AS (collate(Kypper, de) > Köpfe)#x] +Project [(collate(Kypper, de) > Köpfe) AS (collate(Kypper, de) > 'Köpfe' collate de)#x] +- OneRowRelation -- !query select 'I' collate tr_ci = 'ı' -- !query analysis -Project [(collate(I, tr_ci) = cast(ı as string collate tr_CI)) AS (collate(I, tr_ci) = ı)#x] +Project [(collate(I, tr_ci) = ı) AS (collate(I, tr_ci) = 'ı' collate tr_CI)#x] +- OneRowRelation @@ -826,7 +826,7 @@ InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_d -- !query select concat_ws(' ', utf8_lcase, utf8_lcase) from t5 -- !query analysis -Project [concat_ws(cast( as string collate UTF8_LCASE), utf8_lcase#x, utf8_lcase#x) AS concat_ws( , utf8_lcase, utf8_lcase)#x] +Project [concat_ws( , utf8_lcase#x, utf8_lcase#x) AS concat_ws( , utf8_lcase, utf8_lcase)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -868,7 +868,7 @@ Project [concat_ws(collate( , utf8_lcase), cast(utf8_binary#x as string collate -- !query select concat_ws(',', utf8_lcase, 'word'), concat_ws(',', utf8_binary, 'word') from t5 -- !query analysis -Project [concat_ws(cast(, as string collate UTF8_LCASE), utf8_lcase#x, cast(word as string collate UTF8_LCASE)) AS concat_ws(,, utf8_lcase, word)#x, concat_ws(,, utf8_binary#x, word) AS concat_ws(,, utf8_binary, word)#x] +Project [concat_ws(,, utf8_lcase#x, word) AS concat_ws(,, utf8_lcase, word)#x, concat_ws(,, utf8_binary#x, word) AS concat_ws(,, utf8_binary, word)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -876,7 +876,7 @@ Project [concat_ws(cast(, as string collate UTF8_LCASE), utf8_lcase#x, cast(word -- !query select concat_ws(',', utf8_lcase, 'word' collate utf8_binary), concat_ws(',', utf8_binary, 'word' collate utf8_lcase) from t5 -- !query analysis -Project [concat_ws(,, cast(utf8_lcase#x as string), collate(word, utf8_binary)) AS concat_ws(,, utf8_lcase, collate(word, utf8_binary))#x, concat_ws(cast(, as string collate UTF8_LCASE), cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase)) AS concat_ws(,, utf8_binary, collate(word, utf8_lcase))#x] +Project [concat_ws(,, cast(utf8_lcase#x as string), collate(word, utf8_binary)) AS concat_ws(,, utf8_lcase, collate(word, utf8_binary))#x, concat_ws(,, cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase)) AS concat_ws(,, utf8_binary, collate(word, utf8_lcase))#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -934,7 +934,7 @@ Project [elt(1, collate(utf8_binary#x, utf8_binary), cast(utf8_lcase#x as string -- !query select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from t5 -- !query analysis -Project [elt(1, utf8_binary#x, word, 
true) AS elt(1, utf8_binary, word)#x, elt(1, utf8_lcase#x, cast(word as string collate UTF8_LCASE), true) AS elt(1, utf8_lcase, word)#x] +Project [elt(1, utf8_binary#x, word, true) AS elt(1, utf8_binary, word)#x, elt(1, utf8_lcase#x, word, true) AS elt(1, utf8_lcase, 'word' collate UTF8_LCASE)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1024,7 +1024,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 -- !query analysis -Project [split_part(utf8_binary#x, a, 3) AS split_part(utf8_binary, a, 3)#x, split_part(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 3) AS split_part(utf8_lcase, a, 3)#x] +Project [split_part(utf8_binary#x, a, 3) AS split_part(utf8_binary, a, 3)#x, split_part(utf8_lcase#x, a, 3) AS split_part(utf8_lcase, a, 3)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1122,7 +1122,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 -- !query analysis -Project [Contains(utf8_binary#x, a) AS contains(utf8_binary, a)#x, Contains(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS contains(utf8_lcase, a)#x] +Project [Contains(utf8_binary#x, a) AS contains(utf8_binary, a)#x, Contains(utf8_lcase#x, a) AS contains(utf8_lcase, a)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1220,7 +1220,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 -- !query analysis -Project [substring_index(utf8_binary#x, a, 2) AS substring_index(utf8_binary, a, 2)#x, substring_index(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2) AS substring_index(utf8_lcase, a, 2)#x] +Project [substring_index(utf8_binary#x, a, 2) AS substring_index(utf8_binary, a, 2)#x, substring_index(utf8_lcase#x, a, 2) AS substring_index(utf8_lcase, a, 2)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1318,7 +1318,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 -- !query analysis -Project [instr(utf8_binary#x, a) AS instr(utf8_binary, a)#x, instr(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS instr(utf8_lcase, a)#x] +Project [instr(utf8_binary#x, a) AS instr(utf8_binary, a)#x, instr(utf8_lcase#x, a) AS instr(utf8_lcase, a)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1384,7 +1384,7 @@ Project [find_in_set(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, u -- !query select find_in_set(utf8_binary, 'aaAaaAaA,i̇o'), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o') from t5 -- !query analysis -Project [find_in_set(utf8_binary#x, aaAaaAaA,i̇o) AS find_in_set(utf8_binary, aaAaaAaA,i̇o)#x, find_in_set(utf8_lcase#x, cast(aaAaaAaA,i̇o as string collate UTF8_LCASE)) AS find_in_set(utf8_lcase, aaAaaAaA,i̇o)#x] +Project [find_in_set(utf8_binary#x, aaAaaAaA,i̇o) AS find_in_set(utf8_binary, aaAaaAaA,i̇o)#x, find_in_set(utf8_lcase#x, aaAaaAaA,i̇o) AS find_in_set(utf8_lcase, aaAaaAaA,i̇o)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation 
spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1482,7 +1482,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query analysis -Project [StartsWith(utf8_binary#x, aaAaaAaA) AS startswith(utf8_binary, aaAaaAaA)#x, StartsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS startswith(utf8_lcase, aaAaaAaA)#x] +Project [StartsWith(utf8_binary#x, aaAaaAaA) AS startswith(utf8_binary, aaAaaAaA)#x, StartsWith(utf8_lcase#x, aaAaaAaA) AS startswith(utf8_lcase, aaAaaAaA)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1506,7 +1506,7 @@ Project [StartsWith(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), coll -- !query select translate(utf8_lcase, utf8_lcase, '12345') from t5 -- !query analysis -Project [translate(utf8_lcase#x, utf8_lcase#x, cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, utf8_lcase, 12345)#x] +Project [translate(utf8_lcase#x, utf8_lcase#x, 12345) AS translate(utf8_lcase, utf8_lcase, 12345)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1572,7 +1572,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 -- !query analysis -Project [translate(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, aaAaaAaA, 12345)#x, translate(utf8_binary#x, aaAaaAaA, 12345) AS translate(utf8_binary, aaAaaAaA, 12345)#x] +Project [translate(utf8_lcase#x, aaAaaAaA, 12345) AS translate(utf8_lcase, aaAaaAaA, 12345)#x, translate(utf8_binary#x, aaAaaAaA, 12345) AS translate(utf8_binary, aaAaaAaA, 12345)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1580,7 +1580,7 @@ Project [translate(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), ca -- !query select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5 -- !query analysis -Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 12345) AS translate(utf8_lcase, collate(aBc, utf8_binary), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x] +Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 12345) AS translate(utf8_lcase, collate(aBc, utf8_binary), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), 12345) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x] +- SubqueryAlias spark_catalog.default.t5 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet @@ -1588,7 +1588,7 @@ Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 1234 -- !query select translate(utf8_lcase, 'aBc ' collate utf8_binary_rtrim, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5 -- !query analysis -Project [translate(cast(utf8_lcase#x as string collate UTF8_BINARY_RTRIM), collate(aBc , utf8_binary_rtrim), cast(12345 as string collate UTF8_BINARY_RTRIM)) AS translate(utf8_lcase, collate(aBc , 
utf8_binary_rtrim), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x]
+Project [translate(cast(utf8_lcase#x as string collate UTF8_BINARY_RTRIM), collate(aBc , utf8_binary_rtrim), 12345) AS translate(utf8_lcase, collate(aBc , utf8_binary_rtrim), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), 12345) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1638,7 +1638,7 @@ Project [replace(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), abc) AS repl
-- !query
select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5
-- !query analysis
-Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), abc)#x]
+Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), abc) AS replace(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1655,7 +1655,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"STRING COLLATE UNICODE_AI\"",
"paramIndex" : "first",
"requiredType" : "\"STRING\"",
- "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\""
+ "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 'abc' collate UNICODE_AI)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -1670,7 +1670,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5
-- !query analysis
-Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA, abc)#x, replace(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_lcase, aaAaaAaA, abc)#x]
+Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA, abc)#x, replace(utf8_lcase#x, aaAaaAaA, abc) AS replace(utf8_lcase, aaAaaAaA, abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1678,7 +1678,7 @@ Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA,
-- !query
select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5
-- !query analysis
-Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_binary, collate(aaAaaAaA, utf8_lcase), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase), abc) AS replace(utf8_binary, collate(aaAaaAaA, utf8_lcase), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1686,7 +1686,7 @@ Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaa
-- !query
select replace(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5
-- !query analysis
-Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA , utf8_lcase_rtrim), cast(abc as string collate UTF8_LCASE_RTRIM)) AS replace(utf8_binary, collate(aaAaaAaA , utf8_lcase_rtrim), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA , utf8_lcase_rtrim), abc) AS replace(utf8_binary, collate(aaAaaAaA , utf8_lcase_rtrim), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1768,7 +1768,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5
-- !query analysis
-Project [EndsWith(utf8_binary#x, aaAaaAaA) AS endswith(utf8_binary, aaAaaAaA)#x, EndsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS endswith(utf8_lcase, aaAaaAaA)#x]
+Project [EndsWith(utf8_binary#x, aaAaaAaA) AS endswith(utf8_binary, aaAaaAaA)#x, EndsWith(utf8_lcase#x, aaAaaAaA) AS endswith(utf8_lcase, aaAaaAaA)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2042,7 +2042,7 @@ Project [overlay(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_
-- !query
select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5
-- !query analysis
-Project [overlay(utf8_binary#x, a, 2, -1) AS overlay(utf8_binary, a, 2, -1)#x, overlay(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2, -1) AS overlay(utf8_lcase, a, 2, -1)#x]
+Project [overlay(utf8_binary#x, a, 2, -1) AS overlay(utf8_binary, a, 2, -1)#x, overlay(utf8_lcase#x, a, 2, -1) AS overlay(utf8_lcase, 'a' collate UTF8_LCASE, 2, -1)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2204,7 +2204,7 @@ Project [levenshtein(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, u
-- !query
select levenshtein(utf8_binary, 'a'), levenshtein(utf8_lcase, 'a') from t5
-- !query analysis
-Project [levenshtein(utf8_binary#x, a, None) AS levenshtein(utf8_binary, a)#x, levenshtein(utf8_lcase#x, cast(a as string collate UTF8_LCASE), None) AS levenshtein(utf8_lcase, a)#x]
+Project [levenshtein(utf8_binary#x, a, None) AS levenshtein(utf8_binary, a)#x, levenshtein(utf8_lcase#x, a, None) AS levenshtein(utf8_lcase, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2390,7 +2390,7 @@ Project [lpad(collate(utf8_binary#x, utf8_binary_rtrim), 8, collate(utf8_lcase#x
-- !query
select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5
-- !query analysis
-Project [rpad(utf8_binary#x, 8, a) AS rpad(utf8_binary, 8, a)#x, rpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS rpad(utf8_lcase, 8, a)#x]
+Project [rpad(utf8_binary#x, 8, a) AS rpad(utf8_binary, 8, a)#x, rpad(utf8_lcase#x, 8, a) AS rpad(utf8_lcase, 8, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2464,7 +2464,7 @@ Project [lpad(collate(utf8_binary#x, utf8_binary_rtrim), 8, collate(utf8_lcase#x
-- !query
select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5
-- !query analysis
-Project [lpad(utf8_binary#x, 8, a) AS lpad(utf8_binary, 8, a)#x, lpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS lpad(utf8_lcase, 8, a)#x]
+Project [lpad(utf8_binary#x, 8, a) AS lpad(utf8_binary, 8, a)#x, lpad(utf8_lcase#x, 8, a) AS lpad(utf8_lcase, 8, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2554,7 +2554,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5
-- !query analysis
-Project [locate(utf8_binary#x, a, 1) AS locate(utf8_binary, a, 1)#x, locate(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 1) AS locate(utf8_lcase, a, 1)#x]
+Project [locate(utf8_binary#x, a, 1) AS locate(utf8_binary, a, 1)#x, locate(utf8_lcase#x, a, 1) AS locate(utf8_lcase, a, 1)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2660,7 +2660,7 @@ Project [trim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binary
-- !query
select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5
-- !query analysis
-Project [trim(utf8_binary#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_binary)#x, trim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(BOTH ABc FROM utf8_lcase)#x]
+Project [trim(utf8_binary#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_binary)#x, trim(utf8_lcase#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_lcase)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2856,7 +2856,7 @@ Project [ltrim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binar
-- !query
select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5
-- !query analysis
-Project [ltrim(utf8_binary#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_binary)#x, ltrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(LEADING ABc FROM utf8_lcase)#x]
+Project [ltrim(utf8_binary#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_binary)#x, ltrim(utf8_lcase#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_lcase)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2954,7 +2954,7 @@ Project [rtrim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binar
-- !query
select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5
-- !query analysis
-Project [rtrim(utf8_binary#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_binary)#x, rtrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(TRAILING ABc FROM utf8_lcase)#x]
+Project [rtrim(utf8_binary#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_binary)#x, rtrim(utf8_lcase#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_lcase)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
index 78539effe188e..988df7de1a3cf 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
@@ -115,7 +115,7 @@ org.apache.spark.sql.AnalysisException
-- !query
select * from values ("one", 2.0), ("two") as data(a, b)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
"sqlState" : "42000",
@@ -157,7 +157,7 @@ org.apache.spark.sql.AnalysisException
-- !query
select * from values ("one"), ("two") as data(a, b)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
"sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out
index f042116182f7d..2333cce874d31 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out
@@ -498,7 +498,7 @@ SELECT a, b,
SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION",
"sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
index f92fc5de8c3f4..e96549f00d6ec 100644
--- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
@@ -479,7 +479,7 @@ struct
+struct<(collate(a, unicode) < 'A' collate UNICODE):boolean>
-- !query output
true
@@ -487,7 +487,7 @@ true
-- !query
select 'a' collate unicode_ci = 'A'
-- !query schema
-struct<(collate(a, unicode_ci) = A):boolean>
+struct<(collate(a, unicode_ci) = 'A' collate UNICODE_CI):boolean>
-- !query output
true
@@ -495,7 +495,7 @@ true
-- !query
select 'a' collate unicode_ai = 'å'
-- !query schema
-struct<(collate(a, unicode_ai) = å):boolean>
+struct<(collate(a, unicode_ai) = 'å' collate UNICODE_AI):boolean>
-- !query output
true
@@ -503,7 +503,7 @@ true
-- !query
select 'a' collate unicode_ci_ai = 'Å'
-- !query schema
-struct<(collate(a, unicode_ci_ai) = Å):boolean>
+struct<(collate(a, unicode_ci_ai) = 'Å' collate UNICODE_CI_AI):boolean>
-- !query output
true
@@ -511,7 +511,7 @@ true
-- !query
select 'a' collate en < 'A'
-- !query schema
-struct<(collate(a, en) < A):boolean>
+struct<(collate(a, en) < 'A' collate en):boolean>
-- !query output
true
@@ -519,7 +519,7 @@ true
-- !query
select 'a' collate en_ci = 'A'
-- !query schema
-struct<(collate(a, en_ci) = A):boolean>
+struct<(collate(a, en_ci) = 'A' collate en_CI):boolean>
-- !query output
true
@@ -527,7 +527,7 @@ true
-- !query
select 'a' collate en_ai = 'å'
-- !query schema
-struct<(collate(a, en_ai) = å):boolean>
+struct<(collate(a, en_ai) = 'å' collate en_AI):boolean>
-- !query output
true
@@ -535,7 +535,7 @@ true
-- !query
select 'a' collate en_ci_ai = 'Å'
-- !query schema
-struct<(collate(a, en_ci_ai) = Å):boolean>
+struct<(collate(a, en_ci_ai) = 'Å' collate en_CI_AI):boolean>
-- !query output
true
@@ -543,7 +543,7 @@ true
-- !query
select 'Kypper' collate sv < 'Köpfe'
-- !query schema
-struct<(collate(Kypper, sv) < Köpfe):boolean>
+struct<(collate(Kypper, sv) < 'Köpfe' collate sv):boolean>
-- !query output
true
@@ -551,7 +551,7 @@ true
-- !query
select 'Kypper' collate de > 'Köpfe'
-- !query schema
-struct<(collate(Kypper, de) > Köpfe):boolean>
+struct<(collate(Kypper, de) > 'Köpfe' collate de):boolean>
-- !query output
true
@@ -559,7 +559,7 @@ true
-- !query
select 'I' collate tr_ci = 'ı'
-- !query schema
-struct<(collate(I, tr_ci) = ı):boolean>
+struct<(collate(I, tr_ci) = 'ı' collate tr_CI):boolean>
-- !query output
true
@@ -1109,7 +1109,7 @@ kitten
-- !query
select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from t5
-- !query schema
-struct
+struct
-- !query output
Hello, world! Nice day. Hello, world! Nice day.
Something else. Nothing here. Something else. Nothing here.
@@ -2492,7 +2492,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"STRING COLLATE UNICODE_AI\"",
"paramIndex" : "first",
"requiredType" : "\"STRING\"",
- "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\""
+ "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 'abc' collate UNICODE_AI)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -3342,7 +3342,7 @@ ksitTing
-- !query
select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5
-- !query schema
-struct
+struct
-- !query output
Hallo, world! Nice day. Hallo, world! Nice day.
Saark SaL
diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
index 0a2c7b0f55ed2..4dcdf8ac3e980 100644
--- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
@@ -131,7 +131,7 @@ select * from values ("one", 2.0), ("two") as data(a, b)
-- !query schema
struct<>
-- !query output
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
"sqlState" : "42000",
@@ -177,7 +177,7 @@ select * from values ("one"), ("two") as data(a, b)
-- !query schema
struct<>
-- !query output
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
"sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
index 2d539725b2a70..2085186dc8cfa 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
@@ -497,7 +497,7 @@ FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b)
-- !query schema
struct<>
-- !query output
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION",
"sqlState" : "42000",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala
index 57c4441b165ac..cd713016a96e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala
@@ -452,7 +452,8 @@ class CollationSQLRegexpSuite
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
- "sqlExpr" -> "\"regexp_replace(collate(ABCDE, UNICODE_CI), .c., FFF, 1)\"",
+ "sqlExpr" -> ("\"regexp_replace(collate(ABCDE, UNICODE_CI), " +
+ ".c., 'FFF' collate UNICODE_CI, 1)\""),
"paramIndex" -> "first",
"inputSql" -> "\"collate(ABCDE, UNICODE_CI)\"",
"inputType" -> "\"STRING COLLATE UNICODE_CI\"",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
index 9d934763ccb87..9a6b645b44a48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
@@ -470,4 +470,18 @@ class DefaultCollationTestSuiteV2 extends DefaultCollationTestSuite with Datasou
checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable"), Seq(Row(0)))
}
}
+
+ test("inline table in RTAS") {
+ withSessionCollationAndTable("UTF8_LCASE", testTable) {
+ sql(s"""
+ |REPLACE TABLE $testTable AS
+ |SELECT *
+ |FROM (VALUES ('a', 'a' = 'A'))
+ |AS inline_table(c1, c2);
+ |""".stripMargin)
+
+ assertTableColumnCollation(testTable, "c1", "UTF8_BINARY")
+ checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0)))
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 9f1954cbf6868..1f2e0780f43e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -858,7 +858,7 @@ private[hive] trait HiveInspectors {
* @return Hive java objectinspector (recursively).
*/
def toInspector(expr: Expression): ObjectInspector = expr match {
- case Literal(value, StringType) =>
+ case Literal(value, _: StringType) =>
getStringWritableConstantObjectInspector(value)
case Literal(value, IntegerType) =>
getIntWritableConstantObjectInspector(value)
@@ -1147,7 +1147,7 @@ private[hive] trait HiveInspectors {
case IntegerType => intTypeInfo
case LongType => longTypeInfo
case ShortType => shortTypeInfo
- case StringType => stringTypeInfo
+ case _: StringType => stringTypeInfo
case d: DecimalType => decimalTypeInfo(d)
case DateType => dateTypeInfo
case TimestampType => timestampTypeInfo
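A note on the HiveInspectors hunks above: they work because StringType in this series is both a class (any collation) and a case object (the default, collation id 0), so `case StringType` matches only the default while `case _: StringType` matches every collation. A minimal, self-contained sketch of that distinction, with simplified stand-in definitions rather than the actual Spark classes:

class StrType(val collationId: Int) {
  // Value equality by collation id, mirroring how the default StringType
  // object compares equal to an explicit UTF8_BINARY string type.
  override def equals(other: Any): Boolean = other match {
    case s: StrType => s.collationId == collationId
    case _ => false
  }
  override def hashCode(): Int = collationId
}
case object StrType extends StrType(0)

def describe(dt: AnyRef): String = dt match {
  case StrType    => "default-collation string only"  // matches the case object
  case _: StrType => "string of any collation"        // matches every instance
  case _          => "not a string type"
}

// describe(StrType)         == "default-collation string only"
// describe(new StrType(42)) == "string of any collation"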
From 95af604ef1ff1abfb2438a141b7d2e4482ef4b22 Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Tue, 26 Nov 2024 00:19:28 +0100
Subject: [PATCH 39/54] init

---
 .../spark/sql/catalyst/analysis/Analyzer.scala |  1 +
 .../analysis/ResolveDefaultStringTypes.scala   | 17 ++++-------------
 .../catalyst/analysis/ResolveInlineTables.scala | 14 ++------------
 .../spark/sql/catalyst/parser/AstBuilder.scala |  2 +-
 .../sql/catalyst/plans/logical/v2Commands.scala |  2 ++
 .../spark/sql/execution/datasources/ddl.scala  |  4 ++--
 .../sql/internal/BaseSessionStateBuilder.scala |  3 +--
 .../collation/DefaultCollationTestSuite.scala  |  1 +
 .../sql/hive/HiveSessionStateBuilder.scala     |  3 +--
 9 files changed, 15 insertions(+), 32 deletions(-)
 rename sql/{core => catalyst}/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala (90%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bed7bea61597f..3e4db4c39661e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -319,6 +319,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveAliases ::
ResolveSubquery ::
ResolveSubqueryColumnAliases ::
+ ResolveDefaultStringTypes ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
similarity index 90%
rename from sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
index 8442b9b43b867..03819404290c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
@@ -18,13 +18,11 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan}
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.rules.RuleExecutor.ONE_MORE_ITER
-import org.apache.spark.sql.execution.command.CreateViewCommand
-import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StringType, StructType}
+import org.apache.spark.sql.types.{DataType, StringType}

/**
* Resolves default string types in DDL commands. For DML commands, the default string type is
@@ -82,12 +80,12 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
}

private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists {
- case _: CreateTable | _: AddColumns | _: ReplaceColumns | _: AlterColumn => true
+ case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true
case _ => isCreateOrAlterPlan(plan)
}

private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match {
- case _: V2CreateTablePlan | _: CreateView | _: CreateViewCommand | _: AlterViewAs => true
+ case _: V1CreateTablePlan | _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true
case _ => false
}

@@ -95,12 +93,6 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
val newType = stringTypeForDDLCommand(plan)

plan resolveOperators {
- case createTable: CreateTable =>
- val newSchema = replaceDefaultStringType(createTable.tableDesc.schema, newType)
- .asInstanceOf[StructType]
- val withNewSchema = createTable.copy(createTable.tableDesc.copy(schema = newSchema))
- transformPlan(withNewSchema, newType)
-
case p if isCreateOrAlterPlan(p) =>
transformPlan(p, newType)

@@ -205,4 +197,3 @@ case class TemporaryStringType(override val collationId: Int)
extends StringType(collationId) {
override def toString: String = s"TemporaryStringType($collationId)"
}
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index ff781f32a0ab1..62f3997491c07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable
-import org.apache.spark.sql.catalyst.expressions.{EvalHelper, Expression, ExprUtils}
+import org.apache.spark.sql.catalyst.expressions.EvalHelper
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
@@ -29,18 +29,8 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) {
- case table: UnresolvedInlineTable if canResolveInlineTable(table) =>
+ case table: UnresolvedInlineTable if table.expressionsResolved =>
EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table)
}
}
-
- def canResolveInlineTable(
- table: UnresolvedInlineTable,
- checkStringTypes: Boolean = true): Boolean = {
- table.expressionsResolved &&
- (!checkStringTypes || !table.expressions.exists(hasUnresolvedStringType))
- }
-
- private def hasUnresolvedStringType(expression: Expression): Boolean =
- ExprUtils.hasExpressionWithDefaultStringType(expression)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 30c37b36e7eb6..f48548bde5f13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2151,7 +2151,7 @@ class AstBuilder extends DataTypeAstBuilder
// that there are no unresolved string types in the expressions
val checkStringTypes = isSessionCollationSet || contextInsideCreate(ctx)
- ResolveInlineTables.canResolveInlineTable(table, checkStringTypes)
+ !checkStringTypes
}

private def contextInsideCreate(ctx: ParserRuleContext): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index b465e0e11612f..bffefbf2f4fea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -459,6 +459,8 @@ trait V2CreateTableAsSelectPlan
newQuery: LogicalPlan): V2CreateTableAsSelectPlan
}

+trait V1CreateTablePlan extends LogicalPlan {}
+
/** A trait used for logical plan nodes that create or replace V2 table definitions. */
trait V2CreateTablePlan extends LogicalPlan {
def name: LogicalPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index d9367d92d462e..eb9d5813cff7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types._
case class CreateTable(
tableDesc: CatalogTable,
mode: SaveMode,
- query: Option[LogicalPlan]) extends LogicalPlan {
+ query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan {
assert(tableDesc.provider.isDefined, "The table to be created must have a provider.")

if (query.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 8119c9d7b203a..4bf08aef5611a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.annotation.Unstable
import org.apache.spark.sql._
import org.apache.spark.sql.artifact.ArtifactManager
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypes, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -204,7 +204,6 @@ abstract class BaseSessionStateBuilder(
new ResolveSQLOnFile(session) +:
new FallBackFileSourceV2(session) +:
ResolveEncodersInScalaAgg +:
- ResolveDefaultStringTypes +:
new ResolveSessionCatalog(this.catalogManager) +:
ResolveWriteToStream +:
new EvalSubqueriesForTimeTravel +:
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
index 9a6b645b44a48..ed3fac2311704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
@@ -473,6 +473,7 @@ class DefaultCollationTestSuiteV2 extends DefaultCollationTestSuite with Datasou
test("inline table in RTAS") {
withSessionCollationAndTable("UTF8_LCASE", testTable) {
+ sql(s"CREATE TABLE $testTable (c1 STRING, c2 BOOLEAN) USING $dataSource")
sql(s"""
|REPLACE TABLE $testTable AS
|SELECT *
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index e42d01a8d7264..dbeb8607facc2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF}
import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}

import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDefaultStringTypes, ResolveSessionCatalog, ResolveTranspose}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose}
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -90,7 +90,6 @@ class HiveSessionStateBuilder(
new ResolveSQLOnFile(session) +:
new FallBackFileSourceV2(session) +:
ResolveEncodersInScalaAgg +:
- ResolveDefaultStringTypes +:
new ResolveSessionCatalog(catalogManager) +:
ResolveWriteToStream +:
new EvalSubqueriesForTimeTravel +:

From ba4defb83c9d59247ba2cd2d501e1456b64c46aa Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Tue, 26 Nov 2024 00:31:27 +0100
Subject: [PATCH 40/54] improvements

---
 .../analysis/CollationTypeCoercion.scala     |  4 ---
 .../analysis/ResolveDefaultStringTypes.scala | 11 +-------
 .../sql/catalyst/expressions/ExprUtils.scala | 25 +------------------
 .../sql/catalyst/rules/RuleExecutor.scala    | 14 ++---------
 .../apache/spark/sql/util/SchemaUtils.scala  |  8 ------
 .../internal/BaseSessionStateBuilder.scala   |  2 +-
 .../spark/sql/hive/HiveInspectors.scala      |  4 +--
 7 files changed, 7 insertions(+), 61 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 03701a5e5f0aa..c72a70c4ce0c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -219,10 +219,6 @@ object CollationTypeCoercion {
if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
// if there are no collated types we don't need to do anything
return None
- } else if (expressions.exists(ExprUtils.hasExpressionWithDefaultStringType)) {
- // if there are default string types, return None as they need to be resolved
- // first in [[ResolveDefaultStringTypes]]
- return None
}

val collationContextWinner = expressions.foldLeft(findCollationContext(expressions.head)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
index 03819404290c0..8c2baaa6763a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.rules.RuleExecutor.ONE_MORE_ITER
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StringType}

@@ -44,7 +43,7 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
transformPlan(plan, newType)
}

- val finalPlan = if (!replaceWithTempType || newPlan.fastEquals(plan)) {
+ if (!replaceWithTempType || newPlan.fastEquals(plan)) {
newPlan
} else {
// Due to how tree transformations work and StringType object being equal to
// to ensure the correct results for occurrences of default string type.
ResolveDefaultStringTypesWithoutTempType.apply(newPlan)
}
-
- if (finalPlan == plan && finalPlan == newPlan) {
- finalPlan.unsetTagValue(ONE_MORE_ITER)
- } else {
- finalPlan.setTagValue(ONE_MORE_ITER, ())
- }
-
- finalPlan
}

private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType
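The hunk above removes the ONE_MORE_ITER bookkeeping but keeps the two-pass rewrite. The pitfall it guards against, sketched here with stand-in names (this illustrates the comment in the diff, not the real rule): the unresolved default string type is value-equal to the resolved UTF8_BINARY type, so a direct rewrite can look like a no-op to any equality-based change detection; routing through a temporary type that equals neither makes both passes visible.

final class StrType(val collationId: Int, val isDefault: Boolean) {
  override def equals(o: Any): Boolean = o match {
    case s: StrType => s.collationId == collationId // ignores isDefault
    case _ => false
  }
  override def hashCode(): Int = collationId
}

val default   = new StrType(0, isDefault = true)
val resolved  = new StrType(0, isDefault = false)
val temporary = new StrType(Int.MinValue, isDefault = false) // placeholder

// default == resolved is true, so a direct default -> resolved rewrite is
// invisible to ==; default -> temporary -> resolved changes the value at
// every step, so both transformations register as real changes.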
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 035fc32d27bfc..e65a0200b064f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnDefinition}
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
@@ -223,27 +223,4 @@ object ExprUtils extends EvalHelper with QueryErrorsBase {
a.groupingExpressions.foreach(checkValidGroupingExprs)
a.aggregateExpressions.foreach(checkValidAggregateExpression)
}
-
- /**
- * Returns whether any of the expression in the expression tree of the input
- * has a default string type which should be resolved first.
- */
- def hasExpressionWithDefaultStringType(expression: Expression): Boolean = {
- expression.exists {
- case e @ (_: Literal | _: ColumnDefinition | _: Cast | _: NamedLambdaVariable) =>
- hasDefaultStringType(e.dataType)
- case _ =>
- false
- }
- }
-
- /**
- * Returns whether the input data type has a default string type which should be resolved first.
- */
- def hasDefaultStringType(dataType: DataType): Boolean =
- dataType.existsRecursively {
- case st: StringType =>
- st.eq(StringType)
- case _ => false
- }
}
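For reference, the helper deleted above boils down to a recursive reference-equality check: only the shared default StringType object, not a value-equal UTF8_BINARY instance, marks a type as still unresolved. A hedged reconstruction of that logic, spelled out over nested Spark types instead of existsRecursively:

import org.apache.spark.sql.types._

def hasDefaultStringType(dt: DataType): Boolean = dt match {
  case st: StringType     => st.eq(StringType) // reference, not value, equality
  case ArrayType(et, _)   => hasDefaultStringType(et)
  case MapType(k, v, _)   => hasDefaultStringType(k) || hasDefaultStringType(v)
  case StructType(fields) => fields.exists(f => hasDefaultStringType(f.dataType))
  case _                  => false
}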
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 483d8cab3270e..3fcf9c13005e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -17,14 +17,12 @@ package org.apache.spark.sql.catalyst.rules
-import scala.collection.mutable.ListBuffer
-
import org.apache.spark.SparkException
import org.apache.spark.internal.{Logging, MessageWithContext}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.MDC
import org.apache.spark.sql.catalyst.QueryPlanningTracker
-import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -32,7 +30,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

object RuleExecutor {
- val ONE_MORE_ITER: TreeNodeTag[Unit] = TreeNodeTag[Unit]("one_more_iter")

protected val queryExecutionMeter = QueryExecutionMetering()

@@ -240,8 +237,6 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
// Run until fix point or the max number of iterations as specified in the strategy.
while (continue) {
- val effectiveRules = new ListBuffer[Rule[TreeType]]()
-
curPlan = batch.rules.foldLeft(curPlan) {
case (plan, rule) =>
val startTime = System.nanoTime()
@@ -252,7 +247,6 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
if (effective) {
queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName)
queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime)
- effectiveRules.addOne(rule)
planChangeLogger.logRule(rule.ruleName, plan, result)
// Run the plan changes validation after each rule.
if (fullValidation || lightweightValidation) {
@@ -310,7 +304,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
continue = false
}

- if (reachedFixedPoint(lastPlan, curPlan)) {
+ if (curPlan.fastEquals(lastPlan)) {
logTrace(
s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.")
continue = false
@@ -324,8 +318,4 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
curPlan
}
-
- private def reachedFixedPoint(oldPlan: TreeType, newPlan: TreeType): Boolean = {
- oldPlan.fastEquals(newPlan) && newPlan.getTagValue(RuleExecutor.ONE_MORE_ITER).isEmpty
- }
}
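The RuleExecutor change above drops the ONE_MORE_ITER escape hatch, returning fixed-point detection to a plain plan comparison. The loop shape, reduced to a generic sketch (stand-in code, not the real executor):

def fixedPoint[T](plan: T, rules: Seq[T => T], maxIterations: Int): T = {
  var current = plan
  var continue = true
  var iteration = 0
  while (continue && iteration < maxIterations) {
    // apply the whole batch once, feeding each rule's output to the next
    val next = rules.foldLeft(current)((p, rule) => rule(p))
    continue = next != current // stand-in for curPlan.fastEquals(lastPlan)
    current = next
    iteration += 1
  }
  current
}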
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
index 4c2e0a2bee9eb..1e0bac331dc75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -304,14 +304,6 @@ private[spark] object SchemaUtils {
}
}

- /**
- * Checks if a given data type has a default string type (StringType object),
- * which should be resolved to a specific string type in the analyzer.
- */
- def hasDefaultStringType(dt: DataType): Boolean = {
- dt.existsRecursively(_.eq(StringType))
- }
-
/**
* Replaces any collated string type with non collated StringType
* recursively in the given data type.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 4bf08aef5611a..59a873ef982fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -17,7 +17,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.annotation.Unstable
-import org.apache.spark.sql._
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 1f2e0780f43e7..9f1954cbf6868 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -858,7 +858,7 @@ private[hive] trait HiveInspectors {
* @return Hive java objectinspector (recursively).
*/
def toInspector(expr: Expression): ObjectInspector = expr match {
- case Literal(value, _: StringType) =>
+ case Literal(value, StringType) =>
getStringWritableConstantObjectInspector(value)
case Literal(value, IntegerType) =>
getIntWritableConstantObjectInspector(value)
@@ -1147,7 +1147,7 @@ private[hive] trait HiveInspectors {
case IntegerType => intTypeInfo
case LongType => longTypeInfo
case ShortType => shortTypeInfo
- case _: StringType => stringTypeInfo
+ case StringType => stringTypeInfo
case d: DecimalType => decimalTypeInfo(d)
case DateType => dateTypeInfo
case TimestampType => timestampTypeInfo
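The first hunk of the commit below simplifies castStringType: instead of retagging Literal and Cast nodes in place with the target collation, every mismatched expression is now wrapped in an explicit Cast. The observable effect is purely in how plans print, which is what the golden-file churn in this commit records. A stand-in sketch of the before/after (hypothetical constructors, not Catalyst's):

sealed trait Expr
case class Lit(value: String, collation: String) extends Expr
case class CastTo(child: Expr, collation: String) extends Expr

// old behaviour: literals were retagged, so plans printed 'abc' collate UNICODE_AI
def retag(e: Expr, target: String): Expr = e match {
  case Lit(v, _) => Lit(v, target)
  case other     => CastTo(other, target)
}

// new behaviour: always wrap, so plans print cast(abc as string collate UNICODE_AI)
def wrap(e: Expr, target: String): Expr = CastTo(e, target)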
From 999b335fd1c602bb25f3c5852d3489d587ce746c Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Tue, 26 Nov 2024 09:01:35 +0100
Subject: [PATCH 41/54] golden files

---
 .../analysis/CollationTypeCoercion.scala     | 11 +--
 .../analyzer-results/collations.sql.out      | 78 +++++++++----------
 .../analyzer-results/inline-table.sql.out    |  4 +-
 .../postgreSQL/window_part4.sql.out          |  2 +-
 .../sql-tests/results/collations.sql.out     | 28 +++----
 .../sql-tests/results/inline-table.sql.out   |  4 +-
 .../results/postgreSQL/window_part4.sql.out  |  2 +-
 7 files changed, 62 insertions(+), 67 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index c72a70c4ce0c3..e87fa4860fe88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -178,14 +178,9 @@ object CollationTypeCoercion {
* if expression has StringType in the first place.
*/
def castStringType(expr: Expression, st: StringType): Expression = {
- castStringType(expr.dataType, st) match {
- case Some(dt) => expr match {
- case lit: Literal => lit.copy(dataType = dt)
- case cast: Cast => cast.copy(dataType = dt)
- case _ => Cast(expr, dt)
- }
- case _ => expr
- }
+ castStringType(expr.dataType, st)
+ .map(dt => Cast(expr, dt))
+ .getOrElse(expr)
}

private def castStringType(inType: DataType, castType: StringType): Option[DataType] = {
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
index 0d5c414416d40..45ab1cdcff79e 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
@@ -442,77 +442,77 @@ Project [array_except(array(collate(aaa, utf8_lcase)), array(collate(AAA, utf8_l
-- !query
select 'a' collate unicode < 'A'
-- !query analysis
-Project [(collate(a, unicode) < A) AS (collate(a, unicode) < 'A' collate UNICODE)#x]
+Project [(collate(a, unicode) < cast(A as string collate UNICODE)) AS (collate(a, unicode) < A)#x]
+- OneRowRelation
-- !query
select 'a' collate unicode_ci = 'A'
-- !query analysis
-Project [(collate(a, unicode_ci) = A) AS (collate(a, unicode_ci) = 'A' collate UNICODE_CI)#x]
+Project [(collate(a, unicode_ci) = cast(A as string collate UNICODE_CI)) AS (collate(a, unicode_ci) = A)#x]
+- OneRowRelation
-- !query
select 'a' collate unicode_ai = 'å'
-- !query analysis
-Project [(collate(a, unicode_ai) = å) AS (collate(a, unicode_ai) = 'å' collate UNICODE_AI)#x]
+Project [(collate(a, unicode_ai) = cast(å as string collate UNICODE_AI)) AS (collate(a, unicode_ai) = å)#x]
+- OneRowRelation
-- !query
select 'a' collate unicode_ci_ai = 'Å'
-- !query analysis
-Project [(collate(a, unicode_ci_ai) = Å) AS (collate(a, unicode_ci_ai) = 'Å' collate UNICODE_CI_AI)#x]
+Project [(collate(a, unicode_ci_ai) = cast(Å as string collate UNICODE_CI_AI)) AS (collate(a, unicode_ci_ai) = Å)#x]
+- OneRowRelation
-- !query
select 'a' collate en < 'A'
-- !query analysis
-Project [(collate(a, en) < A) AS (collate(a, en) < 'A' collate en)#x]
+Project [(collate(a, en) < cast(A as string collate en)) AS (collate(a, en) < A)#x]
+- OneRowRelation
-- !query
select 'a' collate en_ci = 'A'
-- !query analysis
-Project [(collate(a, en_ci) = A) AS (collate(a, en_ci) = 'A' collate en_CI)#x]
+Project [(collate(a, en_ci) = cast(A as string collate en_CI)) AS (collate(a, en_ci) = A)#x]
+- OneRowRelation
-- !query
select 'a' collate en_ai = 'å'
-- !query analysis
-Project [(collate(a, en_ai) = å) AS (collate(a, en_ai) = 'å' collate en_AI)#x]
+Project [(collate(a, en_ai) = cast(å as string collate en_AI)) AS (collate(a, en_ai) = å)#x]
+- OneRowRelation
-- !query
select 'a' collate en_ci_ai = 'Å'
-- !query analysis
-Project [(collate(a, en_ci_ai) = Å) AS (collate(a, en_ci_ai) = 'Å' collate en_CI_AI)#x]
+Project [(collate(a, en_ci_ai) = cast(Å as string collate en_CI_AI)) AS (collate(a, en_ci_ai) = Å)#x]
+- OneRowRelation
-- !query
select 'Kypper' collate sv < 'Köpfe'
-- !query analysis
-Project [(collate(Kypper, sv) < Köpfe) AS (collate(Kypper, sv) < 'Köpfe' collate sv)#x]
+Project [(collate(Kypper, sv) < cast(Köpfe as string collate sv)) AS (collate(Kypper, sv) < Köpfe)#x]
+- OneRowRelation
-- !query
select 'Kypper' collate de > 'Köpfe'
-- !query analysis
-Project [(collate(Kypper, de) > Köpfe) AS (collate(Kypper, de) > 'Köpfe' collate de)#x]
+Project [(collate(Kypper, de) > cast(Köpfe as string collate de)) AS (collate(Kypper, de) > Köpfe)#x]
+- OneRowRelation
-- !query
select 'I' collate tr_ci = 'ı'
-- !query analysis
-Project [(collate(I, tr_ci) = ı) AS (collate(I, tr_ci) = 'ı' collate tr_CI)#x]
+Project [(collate(I, tr_ci) = cast(ı as string collate tr_CI)) AS (collate(I, tr_ci) = ı)#x]
+- OneRowRelation
@@ -826,7 +826,7 @@ InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_d
-- !query
select concat_ws(' ', utf8_lcase, utf8_lcase) from t5
-- !query analysis
-Project [concat_ws( , utf8_lcase#x, utf8_lcase#x) AS concat_ws( , utf8_lcase, utf8_lcase)#x]
+Project [concat_ws(cast( as string collate UTF8_LCASE), utf8_lcase#x, utf8_lcase#x) AS concat_ws( , utf8_lcase, utf8_lcase)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -868,7 +868,7 @@ Project [concat_ws(collate( , utf8_lcase), cast(utf8_binary#x as string collate
-- !query
select concat_ws(',', utf8_lcase, 'word'), concat_ws(',', utf8_binary, 'word') from t5
-- !query analysis
-Project [concat_ws(,, utf8_lcase#x, word) AS concat_ws(,, utf8_lcase, word)#x, concat_ws(,, utf8_binary#x, word) AS concat_ws(,, utf8_binary, word)#x]
+Project [concat_ws(cast(, as string collate UTF8_LCASE), utf8_lcase#x, cast(word as string collate UTF8_LCASE)) AS concat_ws(,, utf8_lcase, word)#x, concat_ws(,, utf8_binary#x, word) AS concat_ws(,, utf8_binary, word)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -876,7 +876,7 @@ Project [concat_ws(,, utf8_lcase#x, word) AS concat_ws(,, utf8_lcase, word)#x, c
-- !query
select concat_ws(',', utf8_lcase, 'word' collate utf8_binary), concat_ws(',', utf8_binary, 'word' collate utf8_lcase) from t5
-- !query analysis
-Project [concat_ws(,, cast(utf8_lcase#x as string), collate(word, utf8_binary)) AS concat_ws(,, utf8_lcase, collate(word, utf8_binary))#x, concat_ws(,, cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase)) AS concat_ws(,, utf8_binary, collate(word, utf8_lcase))#x]
+Project [concat_ws(,, cast(utf8_lcase#x as string), collate(word, utf8_binary)) AS concat_ws(,, utf8_lcase, collate(word, utf8_binary))#x, concat_ws(cast(, as string collate UTF8_LCASE), cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase)) AS concat_ws(,, utf8_binary, collate(word, utf8_lcase))#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -934,7 +934,7 @@ Project [elt(1, collate(utf8_binary#x, utf8_binary), cast(utf8_lcase#x as string
-- !query
select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from t5
-- !query analysis
-Project [elt(1, utf8_binary#x, word, true) AS elt(1, utf8_binary, word)#x, elt(1, utf8_lcase#x, word, true) AS elt(1, utf8_lcase, 'word' collate UTF8_LCASE)#x]
+Project [elt(1, utf8_binary#x, word, true) AS elt(1, utf8_binary, word)#x, elt(1, utf8_lcase#x, cast(word as string collate UTF8_LCASE), true) AS elt(1, utf8_lcase, word)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
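Every hunk in the remainder of this golden file follows the same pattern: when a function mixes a collated column with a plain string literal, the literal side is now printed as an explicit cast to the column's collation, while the UTF8_BINARY column's literal needs no cast. A toy rendering of that rule, useful only for reading the diffs below (hypothetical helper, not Spark code):

def renderArg(columnType: String, literal: String): String =
  if (columnType == "string") literal // UTF8_BINARY: literal already matches
  else s"cast($literal as $columnType)"

// renderArg("string collate UTF8_LCASE", "a") == "cast(a as string collate UTF8_LCASE)"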
@@ -1024,7 +1024,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5
-- !query analysis
-Project [split_part(utf8_binary#x, a, 3) AS split_part(utf8_binary, a, 3)#x, split_part(utf8_lcase#x, a, 3) AS split_part(utf8_lcase, a, 3)#x]
+Project [split_part(utf8_binary#x, a, 3) AS split_part(utf8_binary, a, 3)#x, split_part(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 3) AS split_part(utf8_lcase, a, 3)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1122,7 +1122,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5
-- !query analysis
-Project [Contains(utf8_binary#x, a) AS contains(utf8_binary, a)#x, Contains(utf8_lcase#x, a) AS contains(utf8_lcase, a)#x]
+Project [Contains(utf8_binary#x, a) AS contains(utf8_binary, a)#x, Contains(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS contains(utf8_lcase, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1220,7 +1220,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5
-- !query analysis
-Project [substring_index(utf8_binary#x, a, 2) AS substring_index(utf8_binary, a, 2)#x, substring_index(utf8_lcase#x, a, 2) AS substring_index(utf8_lcase, a, 2)#x]
+Project [substring_index(utf8_binary#x, a, 2) AS substring_index(utf8_binary, a, 2)#x, substring_index(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2) AS substring_index(utf8_lcase, a, 2)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1318,7 +1318,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5
-- !query analysis
-Project [instr(utf8_binary#x, a) AS instr(utf8_binary, a)#x, instr(utf8_lcase#x, a) AS instr(utf8_lcase, a)#x]
+Project [instr(utf8_binary#x, a) AS instr(utf8_binary, a)#x, instr(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS instr(utf8_lcase, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1384,7 +1384,7 @@ Project [find_in_set(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, u
-- !query
select find_in_set(utf8_binary, 'aaAaaAaA,i̇o'), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o') from t5
-- !query analysis
-Project [find_in_set(utf8_binary#x, aaAaaAaA,i̇o) AS find_in_set(utf8_binary, aaAaaAaA,i̇o)#x, find_in_set(utf8_lcase#x, aaAaaAaA,i̇o) AS find_in_set(utf8_lcase, aaAaaAaA,i̇o)#x]
+Project [find_in_set(utf8_binary#x, aaAaaAaA,i̇o) AS find_in_set(utf8_binary, aaAaaAaA,i̇o)#x, find_in_set(utf8_lcase#x, cast(aaAaaAaA,i̇o as string collate UTF8_LCASE)) AS find_in_set(utf8_lcase, aaAaaAaA,i̇o)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1482,7 +1482,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5
-- !query analysis
-Project [StartsWith(utf8_binary#x, aaAaaAaA) AS startswith(utf8_binary, aaAaaAaA)#x, StartsWith(utf8_lcase#x, aaAaaAaA) AS startswith(utf8_lcase, aaAaaAaA)#x]
+Project [StartsWith(utf8_binary#x, aaAaaAaA) AS startswith(utf8_binary, aaAaaAaA)#x, StartsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS startswith(utf8_lcase, aaAaaAaA)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1506,7 +1506,7 @@ Project [StartsWith(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), coll
-- !query
select translate(utf8_lcase, utf8_lcase, '12345') from t5
-- !query analysis
-Project [translate(utf8_lcase#x, utf8_lcase#x, 12345) AS translate(utf8_lcase, utf8_lcase, 12345)#x]
+Project [translate(utf8_lcase#x, utf8_lcase#x, cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, utf8_lcase, 12345)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1572,7 +1572,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5
-- !query analysis
-Project [translate(utf8_lcase#x, aaAaaAaA, 12345) AS translate(utf8_lcase, aaAaaAaA, 12345)#x, translate(utf8_binary#x, aaAaaAaA, 12345) AS translate(utf8_binary, aaAaaAaA, 12345)#x]
+Project [translate(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, aaAaaAaA, 12345)#x, translate(utf8_binary#x, aaAaaAaA, 12345) AS translate(utf8_binary, aaAaaAaA, 12345)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1580,7 +1580,7 @@ Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 1234
-- !query
select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5
-- !query analysis
-Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 12345) AS translate(utf8_lcase, collate(aBc, utf8_binary), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), 12345) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x]
+Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 12345) AS translate(utf8_lcase, collate(aBc, utf8_binary), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1588,7 +1588,7 @@ Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 1234
-- !query
select translate(utf8_lcase, 'aBc ' collate utf8_binary_rtrim, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5
-- !query analysis
-Project [translate(cast(utf8_lcase#x as string collate UTF8_BINARY_RTRIM), collate(aBc , utf8_binary_rtrim), 12345) AS translate(utf8_lcase, collate(aBc , utf8_binary_rtrim), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), 12345) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x]
+Project [translate(cast(utf8_lcase#x as string collate UTF8_BINARY_RTRIM), collate(aBc , utf8_binary_rtrim), cast(12345 as string collate UTF8_BINARY_RTRIM)) AS translate(utf8_lcase, collate(aBc , utf8_binary_rtrim), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1638,7 +1638,7 @@ Project [replace(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), abc) AS repl
-- !query
select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5
-- !query analysis
-Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), abc) AS replace(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), abc)#x]
+Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1655,7 +1655,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"STRING COLLATE UNICODE_AI\"",
"paramIndex" : "first",
"requiredType" : "\"STRING\"",
- "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 'abc' collate UNICODE_AI)\""
+ "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -1670,7 +1670,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5
-- !query analysis
-Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA, abc)#x, replace(utf8_lcase#x, aaAaaAaA, abc) AS replace(utf8_lcase, aaAaaAaA, abc)#x]
+Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA, abc)#x, replace(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_lcase, aaAaaAaA, abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1678,7 +1678,7 @@ Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA,
-- !query
select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5
-- !query analysis
-Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase), abc) AS replace(utf8_binary, collate(aaAaaAaA, utf8_lcase), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_binary, collate(aaAaaAaA, utf8_lcase), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1686,7 +1686,7 @@ Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaa
-- !query
select replace(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5
-- !query analysis
-Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA , utf8_lcase_rtrim), abc) AS replace(utf8_binary, collate(aaAaaAaA , utf8_lcase_rtrim), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA , utf8_lcase_rtrim), cast(abc as string collate UTF8_LCASE_RTRIM)) AS replace(utf8_binary, collate(aaAaaAaA , utf8_lcase_rtrim), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -1768,7 +1768,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5
-- !query analysis
-Project [EndsWith(utf8_binary#x, aaAaaAaA) AS endswith(utf8_binary, aaAaaAaA)#x, EndsWith(utf8_lcase#x, aaAaaAaA) AS endswith(utf8_lcase, aaAaaAaA)#x]
+Project [EndsWith(utf8_binary#x, aaAaaAaA) AS endswith(utf8_binary, aaAaaAaA)#x, EndsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS endswith(utf8_lcase, aaAaaAaA)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2042,7 +2042,7 @@ Project [overlay(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_
-- !query
select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5
-- !query analysis
-Project [overlay(utf8_binary#x, a, 2, -1) AS overlay(utf8_binary, a, 2, -1)#x, overlay(utf8_lcase#x, a, 2, -1) AS overlay(utf8_lcase, 'a' collate UTF8_LCASE, 2, -1)#x]
+Project [overlay(utf8_binary#x, a, 2, -1) AS overlay(utf8_binary, a, 2, -1)#x, overlay(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2, -1) AS overlay(utf8_lcase, a, 2, -1)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2204,7 +2204,7 @@ Project [levenshtein(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, u
-- !query
select levenshtein(utf8_binary, 'a'), levenshtein(utf8_lcase, 'a') from t5
-- !query analysis
-Project [levenshtein(utf8_binary#x, a, None) AS levenshtein(utf8_binary, a)#x, levenshtein(utf8_lcase#x, a, None) AS levenshtein(utf8_lcase, a)#x]
+Project [levenshtein(utf8_binary#x, a, None) AS levenshtein(utf8_binary, a)#x, levenshtein(utf8_lcase#x, cast(a as string collate UTF8_LCASE), None) AS levenshtein(utf8_lcase, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2390,7 +2390,7 @@ Project [lpad(collate(utf8_binary#x, utf8_binary_rtrim), 8, collate(utf8_lcase#x
-- !query
select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5
-- !query analysis
-Project [rpad(utf8_binary#x, 8, a) AS rpad(utf8_binary, 8, a)#x, rpad(utf8_lcase#x, 8, a) AS rpad(utf8_lcase, 8, a)#x]
+Project [rpad(utf8_binary#x, 8, a) AS rpad(utf8_binary, 8, a)#x, rpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS rpad(utf8_lcase, 8, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2464,7 +2464,7 @@ Project [lpad(collate(utf8_binary#x, utf8_binary_rtrim), 8, collate(utf8_lcase#x
-- !query
select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5
-- !query analysis
-Project [lpad(utf8_binary#x, 8, a) AS lpad(utf8_binary, 8, a)#x, lpad(utf8_lcase#x, 8, a) AS lpad(utf8_lcase, 8, a)#x]
+Project [lpad(utf8_binary#x, 8, a) AS lpad(utf8_binary, 8, a)#x, lpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS lpad(utf8_lcase, 8, a)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2554,7 +2554,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5
-- !query analysis
-Project [locate(utf8_binary#x, a, 1) AS locate(utf8_binary, a, 1)#x, locate(utf8_lcase#x, a, 1) AS locate(utf8_lcase, a, 1)#x]
+Project [locate(utf8_binary#x, a, 1) AS locate(utf8_binary, a, 1)#x, locate(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 1) AS locate(utf8_lcase, a, 1)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2660,7 +2660,7 @@ Project [trim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binary
-- !query
select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5
-- !query analysis
-Project [trim(utf8_binary#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_binary)#x, trim(utf8_lcase#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_lcase)#x]
+Project [trim(utf8_binary#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_binary)#x, trim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(BOTH ABc FROM utf8_lcase)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2856,7 +2856,7 @@ Project [ltrim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binar
-- !query
select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5
-- !query analysis
-Project [ltrim(utf8_binary#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_binary)#x, ltrim(utf8_lcase#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_lcase)#x]
+Project [ltrim(utf8_binary#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_binary)#x, ltrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(LEADING ABc FROM utf8_lcase)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
@@ -2954,7 +2954,7 @@ Project [rtrim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binar
-- !query
select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5
-- !query analysis
-Project [rtrim(utf8_binary#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_binary)#x, rtrim(utf8_lcase#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_lcase)#x]
+Project [rtrim(utf8_binary#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_binary)#x, rtrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(TRAILING ABc FROM utf8_lcase)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
index 988df7de1a3cf..78539effe188e 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
@@ -115,7 +115,7 @@ org.apache.spark.sql.AnalysisException
-- !query
select * from values ("one", 2.0), ("two") as data(a, b)
-- !query analysis
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
{
"errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
"sqlState" : "42000",
@@ -157,7 +157,7 @@ org.apache.spark.sql.AnalysisException
-- !query
select * from values ("one"), ("two") as data(a, b)
-- !query analysis
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
{
"errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
"sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out index 2333cce874d31..f042116182f7d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out @@ -498,7 +498,7 @@ SELECT a, b, SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b) -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION", "sqlState" : "42000", diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index e96549f00d6ec..f92fc5de8c3f4 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -479,7 +479,7 @@ struct +struct<(collate(a, unicode) < A):boolean> -- !query output true @@ -487,7 +487,7 @@ true -- !query select 'a' collate unicode_ci = 'A' -- !query schema -struct<(collate(a, unicode_ci) = 'A' collate UNICODE_CI):boolean> +struct<(collate(a, unicode_ci) = A):boolean> -- !query output true @@ -495,7 +495,7 @@ true -- !query select 'a' collate unicode_ai = 'å' -- !query schema -struct<(collate(a, unicode_ai) = 'å' collate UNICODE_AI):boolean> +struct<(collate(a, unicode_ai) = å):boolean> -- !query output true @@ -503,7 +503,7 @@ true -- !query select 'a' collate unicode_ci_ai = 'Å' -- !query schema -struct<(collate(a, unicode_ci_ai) = 'Å' collate UNICODE_CI_AI):boolean> +struct<(collate(a, unicode_ci_ai) = Å):boolean> -- !query output true @@ -511,7 +511,7 @@ true -- !query select 'a' collate en < 'A' -- !query schema -struct<(collate(a, en) < 'A' collate en):boolean> +struct<(collate(a, en) < A):boolean> -- !query output true @@ -519,7 +519,7 @@ true -- !query select 'a' collate en_ci = 'A' -- !query schema -struct<(collate(a, en_ci) = 'A' collate en_CI):boolean> +struct<(collate(a, en_ci) = A):boolean> -- !query output true @@ -527,7 +527,7 @@ true -- !query select 'a' collate en_ai = 'å' -- !query schema -struct<(collate(a, en_ai) = 'å' collate en_AI):boolean> +struct<(collate(a, en_ai) = å):boolean> -- !query output true @@ -535,7 +535,7 @@ true -- !query select 'a' collate en_ci_ai = 'Å' -- !query schema -struct<(collate(a, en_ci_ai) = 'Å' collate en_CI_AI):boolean> +struct<(collate(a, en_ci_ai) = Å):boolean> -- !query output true @@ -543,7 +543,7 @@ true -- !query select 'Kypper' collate sv < 'Köpfe' -- !query schema -struct<(collate(Kypper, sv) < 'Köpfe' collate sv):boolean> +struct<(collate(Kypper, sv) < Köpfe):boolean> -- !query output true @@ -551,7 +551,7 @@ true -- !query select 'Kypper' collate de > 'Köpfe' -- !query schema -struct<(collate(Kypper, de) > 'Köpfe' collate de):boolean> +struct<(collate(Kypper, de) > Köpfe):boolean> -- !query output true @@ -559,7 +559,7 @@ true -- !query select 'I' collate tr_ci = 'ı' -- !query schema -struct<(collate(I, tr_ci) = 'ı' collate tr_CI):boolean> +struct<(collate(I, tr_ci) = ı):boolean> -- !query output true @@ -1109,7 +1109,7 @@ kitten -- !query select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from t5 -- !query schema -struct +struct -- !query output Hello, world! Nice day. Hello, world! Nice day. 
Something else. Nothing here. Something else. Nothing here. @@ -2492,7 +2492,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputType" : "\"STRING COLLATE UNICODE_AI\"", "paramIndex" : "first", "requiredType" : "\"STRING\"", - "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 'abc' collate UNICODE_AI)\"" + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" }, "queryContext" : [ { "objectType" : "", @@ -3342,7 +3342,7 @@ ksitTing -- !query select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5 -- !query schema -struct +struct -- !query output Hallo, world! Nice day. Hallo, world! Nice day. Saark SaL diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index 4dcdf8ac3e980..0a2c7b0f55ed2 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -131,7 +131,7 @@ select * from values ("one", 2.0), ("two") as data(a, b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", @@ -177,7 +177,7 @@ select * from values ("one"), ("two") as data(a, b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out index 2085186dc8cfa..2d539725b2a70 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out @@ -497,7 +497,7 @@ FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION", "sqlState" : "42000", From 70abc52847abf0acf86b81d798f09cead53bae2e Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 26 Nov 2024 09:02:50 +0100 Subject: [PATCH 42/54] fix regex failing test --- .../scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index cd713016a96e1..57c4441b165ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -452,8 +452,7 @@ class CollationSQLRegexpSuite }, condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "sqlExpr" -> ("\"regexp_replace(collate(ABCDE, UNICODE_CI), " + - ".c., 'FFF' collate UNICODE_CI, 1)\""), + "sqlExpr" -> "\"regexp_replace(collate(ABCDE, UNICODE_CI), .c., FFF, 1)\"", "paramIndex" -> "first", "inputSql" -> "\"collate(ABCDE, UNICODE_CI)\"", "inputType" -> "\"STRING COLLATE UNICODE_CI\"", From 3e6e7cb114ba7efd74279886745aa7360326b997 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 
26 Nov 2024 11:49:34 +0100 Subject: [PATCH 43/54] fix failing tests --- .../apache/spark/sql/types/StringType.scala   | 4 -- .../analysis/CollationTypeCoercion.scala      | 4 ++ .../analysis/ResolveDefaultStringTypes.scala  | 43 ++++++++++++------- .../analysis/ResolveInlineTables.scala        | 6 ++- .../sql/catalyst/parser/AstBuilder.scala      | 17 ++++---- .../catalyst/plans/logical/v2Commands.scala   | 4 ++ .../sql/catalyst/rules/RuleExecutor.scala     | 1 - .../analysis/ResolveInlineTablesSuite.scala   | 3 +- .../sql/catalyst/parser/DDLParserSuite.scala  | 4 +- .../sql/CollationStringExpressionsSuite.scala | 14 ------ .../collation/DefaultCollationTestSuite.scala | 10 +++-- 11 files changed, 60 insertions(+), 50 deletions(-)  diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index b130453845654..695d222a9ba73 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -77,10 +77,6 @@ class StringType private[sql] (val collationId: Int) extends AtomicType with Ser     if (isUTF8BinaryCollation) "string"     else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}"  -  override def toString: String = -    if (isUTF8BinaryCollation) "StringType" -    else s"StringType($collationId)" -   // Due to backwards compatibility and compatibility with other readers   // all string types are serialized in json as regular strings and   // the collation information is written to struct field metadata diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index e87fa4860fe88..c6d5fc8a482a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -214,6 +214,10 @@ object CollationTypeCoercion {     if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {       // if there are no collated types we don't need to do anything       return None +    } else if (ResolveDefaultStringTypes.doesNeedResolve(expressions)) { +      // if any of the string types are still not resolved +      // we need to wait for them to be resolved first +      return None     }      val collationContextWinner = expressions.foldLeft(findCollationContext(expressions.head)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 8c2baaa6763a5..5e2bb7c093082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -17,7 +17,7 @@  package org.apache.spark.sql.catalyst.analysis  -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, NamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}  import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan}  import org.apache.spark.sql.catalyst.rules.Rule  import org.apache.spark.sql.internal.SQLConf @@ -53,6 +53,22 @@ 
class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic } } + /** + * Returns whether any of the given `expressions` needs to have its + * default string type resolved. + */ + def doesNeedResolve(expressions: Seq[Expression]): Boolean = { + expressions.exists(doesNeedResolve) + } + + /** + * Returns whether the given `expression` needs to have its + * default string type resolved. + */ + def doesNeedResolve(expression: Expression): Boolean = { + expression.exists(e => transformExpression.isDefinedAt(e, StringType)) + } + private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType /** @@ -106,30 +122,25 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = { plan resolveOperators { operator => operator resolveExpressionsUp { expression => - transformExpression(expression, newType) + transformExpression.applyOrElse((expression, newType), fallbackToExpression) } } } + private def fallbackToExpression(tuple: (Expression, StringType)): Expression = tuple._1 + /** * Transforms the given expression, by changing all default string types to the given new type. */ - private def transformExpression(expression: Expression, newType: StringType): Expression = { - expression match { - case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) => - columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) + private def transformExpression: PartialFunction[(Expression, StringType), Expression] = { + case (columnDef: ColumnDefinition, newType) if hasDefaultStringType(columnDef.dataType) => + columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) - case cast: Cast if hasDefaultStringType(cast.dataType) => - cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) + case (cast: Cast, newType) if hasDefaultStringType(cast.dataType) => + cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) - case Literal(value, dt) if hasDefaultStringType(dt) => - Literal(value, replaceDefaultStringType(dt, newType)) - - case lambdaVar: NamedLambdaVariable if hasDefaultStringType(lambdaVar.dataType) => - lambdaVar.copy(dataType = replaceDefaultStringType(lambdaVar.dataType, newType)) - - case other => other - } + case (Literal(value, dt), newType) if hasDefaultStringType(dt) => + Literal(value, replaceDefaultStringType(dt, newType)) } private def hasDefaultStringType(dataType: DataType): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 62f3997491c07..2f08a3a872fbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,8 +29,12 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper { override def apply(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) { - case table: UnresolvedInlineTable if table.expressionsResolved => + case table: UnresolvedInlineTable if canResolveTable(table) => EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table) } } + + private def canResolveTable(table: UnresolvedInlineTable): Boolean = { 
+    table.expressionsResolved && !ResolveDefaultStringTypes.doesNeedResolve(table.expressions) +  } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f48548bde5f13..0d1d14532ef51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2127,7 +2127,7 @@ class AstBuilder extends DataTypeAstBuilder     }      val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq) -    val table = if (canEagerlyEvaluateInlineTable(unresolvedTable, ctx)) { +    val table = if (canEagerlyEvaluateInlineTable(ctx, unresolvedTable)) {       EvaluateUnresolvedInlineTable.evaluate(unresolvedTable)     } else {       unresolvedTable @@ -2139,19 +2139,20 @@ class AstBuilder extends DataTypeAstBuilder    * Determines if the inline table can be eagerly evaluated.    */   private def canEagerlyEvaluateInlineTable( -      table: UnresolvedInlineTable, -      ctx: InlineTableContext): Boolean = { +      ctx: InlineTableContext, +      table: UnresolvedInlineTable): Boolean = {     if (!conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {       return false +    } else if (!ResolveDefaultStringTypes.doesNeedResolve(table.expressions)) { +      // if there are no strings to be resolved we can always evaluate eagerly +      return true     }      val isSessionCollationSet = conf.defaultStringType != StringType  -    // if either of these are true we also need to check string types to make sure -    // that there are no unresolved string types in the expressions -    val checkStringTypes = isSessionCollationSet || contextInsideCreate(ctx) - -    !checkStringTypes +    // if either of these is true we need to resolve +    // the string types first +    !isSessionCollationSet && !contextInsideCreate(ctx)   }    private def contextInsideCreate(ctx: ParserRuleContext): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index bffefbf2f4fea..a438fad618fbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -459,6 +459,10 @@ trait V2CreateTableAsSelectPlan       newQuery: LogicalPlan): V2CreateTableAsSelectPlan }  +/** + * A trait used for logical plan nodes that create V1 table definitions, + * so that rules from the catalyst module can identify them. + */ trait V1CreateTablePlan extends LogicalPlan {}  /** A trait used for logical plan nodes that create or replace V2 table definitions. */ trait V2CreateTablePlan extends LogicalPlan { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 3fcf9c13005e3..76d36fab2096a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.internal.SQLConf  import org.apache.spark.util.Utils  object RuleExecutor { -   protected val queryExecutionMeter = QueryExecutionMetering()    /** Dump statistics about time spent running specific rules. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index f231164d5c25a..e2e8df6aeddeb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -117,8 +117,9 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) val withTimeZone = ResolveTimeZone.apply(table) + val resolvedStringTypes = ResolveDefaultStringTypes.apply(withTimeZone) val LocalRelation(output, data, _, _) = - EvalInlineTables(ResolveInlineTables.apply(withTimeZone)) + EvalInlineTables(ResolveInlineTables.apply(resolvedStringTypes)) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] assert(output.map(_.dataType) == Seq(TimestampType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 5e871208698af..062a7e8930933 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -2601,7 +2601,9 @@ class DDLParserSuite extends AnalysisTest { Map("part" -> Some(part)), Seq.empty[String], if (optimizeInsertIntoCmds) { - ResolveInlineTables(UnresolvedInlineTable(Seq("col1"), Seq(Seq(Literal("a"))))) + val resolvedStringTypes = ResolveDefaultStringTypes( + UnresolvedInlineTable(Seq("col1"), Seq(Seq(Literal("a"))))) + ResolveInlineTables(resolvedStringTypes) } else { UnresolvedInlineTable(Seq("col1"), Seq(Seq(Literal("a")))) }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 2a0b84c075079..26eb79c38f39d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -194,20 +194,6 @@ class CollationStringExpressionsSuite "implicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE"""" ) ) - - checkError( - exception = intercept[AnalysisException] { - val expr = StringSplitSQL( - Collate(Literal.create("1a2"), "UTF8_BINARY"), - Collate(Literal.create("a"), "UTF8_LCASE")) - CollationTypeCasts.transform(expr) - }, - condition = "COLLATION_MISMATCH.EXPLICIT", - sqlState = "42P21", - parameters = Map( - "explicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE"""" - ) - ) } test("Support `Contains` string expression with collation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala index ed3fac2311704..0de638d4e9bf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -190,8 +190,9 @@ abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSessi test("inline table in CTAS") { 
withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s""" - |CREATE TABLE $testTable AS - |SELECT * + |CREATE TABLE $testTable + |USING $dataSource + |AS SELECT * |FROM (VALUES ('a', 'a' = 'A')) |AS inline_table(c1, c2); |""".stripMargin) @@ -475,8 +476,9 @@ class DefaultCollationTestSuiteV2 extends DefaultCollationTestSuite with Datasou withSessionCollationAndTable("UTF8_LCASE", testTable) { sql(s"CREATE TABLE $testTable (c1 STRING, c2 BOOLEAN) USING $dataSource") sql(s""" - |REPLACE TABLE $testTable AS - |SELECT * + |REPLACE TABLE $testTable + |USING $dataSource + |AS SELECT * |FROM (VALUES ('a', 'a' = 'A')) |AS inline_table(c1, c2); |""".stripMargin) From 83de6e18db4ff6d5a8b31c58d36f66822915de13 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 26 Nov 2024 12:04:00 +0100 Subject: [PATCH 44/54] add test back --- .../sql/CollationStringExpressionsSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 26eb79c38f39d..626bd0b239366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -194,6 +194,20 @@ class CollationStringExpressionsSuite "implicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE"""" ) ) + + checkError( + exception = intercept[AnalysisException] { + val expr = StringSplitSQL( + Collate(Literal.create("1a2", StringType("UTF8_BINARY")), "UTF8_BINARY"), + Collate(Literal.create("a", StringType("UTF8_BINARY")), "UTF8_LCASE")) + CollationTypeCasts.transform(expr) + }, + condition = "COLLATION_MISMATCH.EXPLICIT", + sqlState = "42P21", + parameters = Map( + "explicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE"""" + ) + ) } test("Support `Contains` string expression with collation") { From 2c8b567f8d1fb86526457ad91456eb93f61524b2 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 26 Nov 2024 13:25:15 +0100 Subject: [PATCH 45/54] add proper toString method for StringType --- .../scala/org/apache/spark/sql/types/StringType.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 695d222a9ba73..f9fcd2ab57f83 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -75,7 +75,13 @@ class StringType private[sql] (val collationId: Int) extends AtomicType with Ser */ override def typeName: String = if (isUTF8BinaryCollation) "string" - else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" + else s"string collate $collationName" + + override def toString: String = + if (isUTF8BinaryCollation) "StringType" + else s"StringType($collationName)" + + def collationName: String = CollationFactory.fetchCollation(collationId).collationName // Due to backwards compatibility and compatibility with other readers // all string types are serialized in json as regular strings and From 10033a0c9d5c92271d432d82349bf6c0d45cf617 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 27 Nov 2024 14:05:32 +0100 Subject: [PATCH 46/54] rename methods --- .../apache/spark/sql/types/StringType.scala | 3 ++- .../analysis/CollationTypeCoercion.scala | 2 +- 
.../analysis/ResolveDefaultStringTypes.scala | 12 +++++----- .../analysis/ResolveInlineTables.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../catalyst/plans/logical/v2Commands.scala | 2 +- .../spark/sql/CollationSQLRegexpSuite.scala | 24 +++++++++---------- 7 files changed, 23 insertions(+), 24 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index f9fcd2ab57f83..b2cf502f8bdc1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -81,7 +81,8 @@ class StringType private[sql] (val collationId: Int) extends AtomicType with Ser if (isUTF8BinaryCollation) "StringType" else s"StringType($collationName)" - def collationName: String = CollationFactory.fetchCollation(collationId).collationName + private[sql] def collationName: String = + CollationFactory.fetchCollation(collationId).collationName // Due to backwards compatibility and compatibility with other readers // all string types are serialized in json as regular strings and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index c6d5fc8a482a2..cca1d21df3a7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -214,7 +214,7 @@ object CollationTypeCoercion { if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { // if there are no collated types we don't need to do anything return None - } else if (ResolveDefaultStringTypes.doesNeedResolve(expressions)) { + } else if (ResolveDefaultStringTypes.needsResolution(expressions)) { // if any of the strings types are still not resolved // we need to wait for them to be resolved first return None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 5e2bb7c093082..e1b5d479719a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -57,15 +57,15 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic * Returns whether any of the given `expressions` needs to have its * default string type resolved. */ - def doesNeedResolve(expressions: Seq[Expression]): Boolean = { - expressions.exists(doesNeedResolve) + def needsResolution(expressions: Seq[Expression]): Boolean = { + expressions.exists(needsResolution) } /** * Returns whether the given `expression` needs to have its * default string type resolved. 
*/ -  def doesNeedResolve(expression: Expression): Boolean = { +  def needsResolution(expression: Expression): Boolean = {     expression.exists(e => transformExpression.isDefinedAt(e, StringType))   } @@ -169,10 +169,10 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic   }    private def getTemporaryStringType(forType: StringType): StringType = { -    if (forType.collationId == 0) { -      TemporaryStringType(1) +    if (forType == StringType) { +      TemporaryStringType(StringType("UTF8_LCASE").collationId)     } else { -      TemporaryStringType(0) +      TemporaryStringType(StringType.collationId)     }   } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 2f08a3a872fbf..9efb1d5b57509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -35,6 +35,6 @@ object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper {   }    private def canResolveTable(table: UnresolvedInlineTable): Boolean = { -    table.expressionsResolved && !ResolveDefaultStringTypes.doesNeedResolve(table.expressions) +    table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table.expressions)   } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0d1d14532ef51..6f294c3edf1eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2143,7 +2143,7 @@ class AstBuilder extends DataTypeAstBuilder       table: UnresolvedInlineTable): Boolean = {     if (!conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {       return false -    } else if (!ResolveDefaultStringTypes.doesNeedResolve(table.expressions)) { +    } else if (!ResolveDefaultStringTypes.needsResolution(table.expressions)) {       // if there are no strings to be resolved we can always evaluate eagerly       return true     } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index a438fad618fbb..857522728eaff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -463,7 +463,7 @@ trait V2CreateTableAsSelectPlan  * A trait used for logical plan nodes that create V1 table definitions,  * so that rules from the catalyst module can identify them.  */ -trait V1CreateTablePlan extends LogicalPlan {} +trait V1CreateTablePlan extends LogicalPlan  /** A trait used for logical plan nodes that create or replace V2 table definitions. 
*/ trait V2CreateTablePlan extends LogicalPlan { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index 57c4441b165ac..7cafb999ffcf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -111,19 +111,17 @@ class CollationSQLRegexpSuite } val tableNameLcase = "T_LCASE" withTable(tableNameLcase) { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") { - sql(s""" - |CREATE TABLE IF NOT EXISTS $tableNameLcase( - | c STRING COLLATE UTF8_LCASE - |) using PARQUET - |""".stripMargin) - sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')") - checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like 'a%c' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like '%b%' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like 'abc' FROM $tableNameLcase"), Row(true)) - } + sql(s""" + |CREATE TABLE IF NOT EXISTS $tableNameLcase( + | c STRING COLLATE UTF8_LCASE + |) using PARQUET + |""".stripMargin) + sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')") + checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like 'a%c' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like '%b%' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like 'abc' FROM $tableNameLcase"), Row(true)) } } From 9b3035a0efa474210ce750c36382042aab9bcd03 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 27 Nov 2024 16:09:27 +0100 Subject: [PATCH 47/54] address pr comments --- .../apache/spark/sql/types/ArrayType.scala | 9 ++++++ .../org/apache/spark/sql/types/DataType.scala | 21 +------------ .../org/apache/spark/sql/types/MapType.scala | 9 ++++++ .../apache/spark/sql/types/StructType.scala | 12 ++++++++ .../analysis/ResolveDefaultStringTypes.scala | 30 ++++++++++++++----- .../analysis/ResolveInlineTables.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 17 +++++++++-- .../analysis/ResolveInlineTablesSuite.scala | 3 +- .../sql/catalyst/parser/DDLParserSuite.scala | 4 +-- 9 files changed, 72 insertions(+), 35 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index fc32248b4baf3..53dfc5e9b2828 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -110,4 +110,13 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || elementType.existsRecursively(f) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + f(this) + } else { + ArrayType(elementType.transformRecursively(f), containsNull) + } + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 69c87cd2ef99e..683582f43df4b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -109,26 +109,7 @@ abstract class DataType extends AbstractDataType { * Recursively applies the provided partial function `f` to transform this DataType tree. */ private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = { - this match { - case _ if f.isDefinedAt(this) => - f(this) - - case ArrayType(elementType, containsNull) => - ArrayType(elementType.transformRecursively(f), containsNull) - - case MapType(keyType, valueType, valueContainsNull) => - MapType( - keyType.transformRecursively(f), - valueType.transformRecursively(f), - valueContainsNull) - - case StructType(fields) => - StructType(fields.map { field => - field.copy(dataType = field.dataType.transformRecursively(f)) - }) - - case _ => this - } + if (f.isDefinedAt(this)) f(this) else this } final override private[sql] def defaultConcreteType: DataType = this diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala index 1dfb9aaf9e29b..73eceede93918 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -89,6 +89,15 @@ case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Bo override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + f(this) + } else { + MapType(keyType.transformRecursively(f), valueType.transformRecursively(f), valueContainsNull) + } + } } /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index 07f6b50bd4a7a..cc95d8ee94b02 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -502,6 +502,18 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + return f(this) + } + + val newFields = fields.map { field => + field.copy(dataType = field.dataType.transformRecursively(f)) + } + StructType(newFields) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index e1b5d479719a4..70bffc6fb1624 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan} -import org.apache.spark.sql.catalyst.rules.Rule 
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}  import org.apache.spark.sql.internal.SQLConf  import org.apache.spark.sql.types.{DataType, StringType} @@ -36,21 +36,37 @@ import org.apache.spark.sql.types.{DataType, StringType}  class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[LogicalPlan] {    def apply(plan: LogicalPlan): LogicalPlan = { -    val newPlan = if (isDDLCommand(plan)) { -      transformDDL(plan) -    } else { -      val newType = stringTypeForDMLCommand -      transformPlan(plan, newType) +    val newPlan = plan match { +      case _ if isDDLCommand(plan) => transformDDL(plan) +      case _ if isDefaultSessionCollationUsed => plan +      case _ => transformPlan(plan, stringTypeForDMLCommand)     }      if (!replaceWithTempType || newPlan.fastEquals(plan)) { +      newPlan.unsetTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER)       newPlan     } else {       // Due to how tree transformations work and StringType object being equal to       // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice       // to ensure the correct results for occurrences of default string type. -      ResolveDefaultStringTypesWithoutTempType.apply(newPlan) +      val finalPlan = ResolveDefaultStringTypesWithoutTempType.apply(newPlan) +      finalPlan.setTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER, ()) +      finalPlan +    } +  } + +  /** +   * Returns whether the given `plan` needs to have its +   * default string type resolved. +   */ +  def needsResolution(plan: LogicalPlan): Boolean = { +    if (isDDLCommand(plan)) { +      return true +    } else if (!isDefaultSessionCollationUsed) { +      return false     } + +    plan.exists(node => needsResolution(node.expressions))   }    /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 9efb1d5b57509..b9e9e49a39647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -35,6 +35,6 @@ object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper {   }    private def canResolveTable(table: UnresolvedInlineTable): Boolean = { -    table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table.expressions) +    table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table)   } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 76d36fab2096a..ba1971aac6dfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.{Logging, MessageWithContext}  import org.apache.spark.internal.LogKeys._  import org.apache.spark.internal.MDC  import org.apache.spark.sql.catalyst.QueryPlanningTracker -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}  import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS  import org.apache.spark.sql.catalyst.util.sideBySide  import org.apache.spark.sql.errors.QueryExecutionErrors @@ -30,6 +30,14 @@ import org.apache.spark.sql.internal.SQLConf  import org.apache.spark.util.Utils  object RuleExecutor { + +  /** +   * A tag to indicate that we should do another batch iteration even if the plan +   * 
hasn't changed between the start and end of the batch. + * Use with caution as it can lead to infinite loops. + */ + private[spark] val FORCE_ANOTHER_BATCH_ITER = TreeNodeTag[Unit]("forceAnotherBatchIter") + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules. */ @@ -303,7 +311,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { continue = false } - if (curPlan.fastEquals(lastPlan)) { + if (isFixedPointReached(lastPlan, curPlan)) { logTrace( s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.") continue = false @@ -317,4 +325,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { curPlan } + + def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = { + oldPlan.fastEquals(newPlan) && + newPlan.getTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER).isEmpty + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index e2e8df6aeddeb..f231164d5c25a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -117,9 +117,8 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) val withTimeZone = ResolveTimeZone.apply(table) - val resolvedStringTypes = ResolveDefaultStringTypes.apply(withTimeZone) val LocalRelation(output, data, _, _) = - EvalInlineTables(ResolveInlineTables.apply(resolvedStringTypes)) + EvalInlineTables(ResolveInlineTables.apply(withTimeZone)) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] assert(output.map(_.dataType) == Seq(TimestampType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 062a7e8930933..5e871208698af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -2601,9 +2601,7 @@ class DDLParserSuite extends AnalysisTest { Map("part" -> Some(part)), Seq.empty[String], if (optimizeInsertIntoCmds) { - val resolvedStringTypes = ResolveDefaultStringTypes( - UnresolvedInlineTable(Seq("col1"), Seq(Seq(Literal("a"))))) - ResolveInlineTables(resolvedStringTypes) + ResolveInlineTables(UnresolvedInlineTable(Seq("col1"), Seq(Seq(Literal("a"))))) } else { UnresolvedInlineTable(Seq("col1"), Seq(Seq(Literal("a")))) }, From 59190ec3e7e6d99f366147eaadfb433e6748cbea Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 27 Nov 2024 17:09:07 +0100 Subject: [PATCH 48/54] fix minor error in needsResolution logic --- .../spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 70bffc6fb1624..5f92f22b042e7 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -62,7 +62,7 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic   def needsResolution(plan: LogicalPlan): Boolean = {     if (isDDLCommand(plan)) {       return true -    } else if (!isDefaultSessionCollationUsed) { +    } else if (isDefaultSessionCollationUsed) {       return false     }  From 311cdd44a49fb27b38904220cc8fc46babc6239d Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 27 Nov 2024 17:14:04 +0100 Subject: [PATCH 49/54] fix formatting --- .../src/main/scala/org/apache/spark/sql/types/MapType.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-)  diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala index 73eceede93918..de656c13ca4bf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -95,7 +95,10 @@ case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Bo     if (f.isDefinedAt(this)) {       f(this)     } else { -      MapType(keyType.transformRecursively(f), valueType.transformRecursively(f), valueContainsNull) +      MapType( +        keyType.transformRecursively(f), +        valueType.transformRecursively(f), +        valueContainsNull)     }   } }  From d92a5cf592914354d077ef2f427118f2bd6c7a9f Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 27 Nov 2024 19:03:11 +0100 Subject: [PATCH 50/54] bring back if/else --- .../sql/catalyst/analysis/ResolveDefaultStringTypes.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-)  diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 5f92f22b042e7..6ea8a37ff9e86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -36,10 +36,10 @@ import org.apache.spark.sql.types.{DataType, StringType}  class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[LogicalPlan] {    def apply(plan: LogicalPlan): LogicalPlan = { -    val newPlan = plan match { -      case _ if isDDLCommand(plan) => transformDDL(plan) -      case _ if isDefaultSessionCollationUsed => plan -      case _ => transformPlan(plan, stringTypeForDMLCommand) +    val newPlan = if (isDDLCommand(plan)) { +      transformDDL(plan) +    } else { +      transformPlan(plan, stringTypeForDMLCommand)     }      if (!replaceWithTempType || newPlan.fastEquals(plan)) { From 26909ea61f8105bd011672d1d7c1b0606188c2d5 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 27 Nov 2024 23:30:25 +0100 Subject: [PATCH 51/54] change force-another-iteration tag to be an integer --- .../analysis/ResolveDefaultStringTypes.scala  | 3 +-- .../sql/catalyst/rules/RuleExecutor.scala     | 26 ++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-)  diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 6ea8a37ff9e86..2695fe4b54393 100644 --- 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -43,14 +43,13 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic } if (!replaceWithTempType || newPlan.fastEquals(plan)) { - newPlan.unsetTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER) newPlan } else { // Due to how tree transformations work and StringType object being equal to // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice // to ensure the correct results for occurrences of default string type. val finalPlan = ResolveDefaultStringTypesWithoutTempType.apply(newPlan) - finalPlan.setTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER, ()) + RuleExecutor.forceAdditionalIteration(finalPlan) finalPlan } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index ba1971aac6dfc..bdbf698db2e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -22,6 +22,7 @@ import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.catalyst.rules.RuleExecutor.getForceIterationValue import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.sideBySide @@ -32,11 +33,24 @@ import org.apache.spark.util.Utils object RuleExecutor { /** - * A tag to indicate that we should do another batch iteration even if the plan - * hasn't changed between the start and end of the batch. - * Use with caution as it can lead to infinite loops. + * A tag used to explicitly request an additional iteration of the current batch during + * rule execution, even if the query plan remains unchanged. Increment the tag's value + * to enforce another iteration. */ - private[spark] val FORCE_ANOTHER_BATCH_ITER = TreeNodeTag[Unit]("forceAnotherBatchIter") + private val FORCE_ADDITIONAL_ITERATION = TreeNodeTag[Int]("forceAdditionalIteration") + + /** + * Increments the value of the FORCE_ADDITIONAL_ITERATION tag on the given plan to + * explicitly force another iteration of the current batch during rule execution. 
+   */
+  def forceAdditionalIteration(plan: TreeNode[_]): Unit = {
+    val oldValue = getForceIterationValue(plan)
+    plan.setTagValue(FORCE_ADDITIONAL_ITERATION, oldValue + 1)
+  }
+
+  private def getForceIterationValue(plan: TreeNode[_]): Int = {
+    plan.getTagValue(FORCE_ADDITIONAL_ITERATION).getOrElse(0)
+  }
 
   protected val queryExecutionMeter = QueryExecutionMetering()
@@ -326,8 +340,8 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
     curPlan
   }
 
-  def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = {
+  private def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = {
     oldPlan.fastEquals(newPlan) &&
-      newPlan.getTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER).isEmpty
+      getForceIterationValue(newPlan) <= getForceIterationValue(oldPlan)
   }
 }

From 203d7e5ed8f8669ad678d73392b5bf3b94deed8b Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Thu, 28 Nov 2024 10:34:26 +0100
Subject: [PATCH 52/54] address PR comments

---
 .../analysis/ResolveDefaultStringTypes.scala  | 29 +++++----------------
 .../analysis/ResolveSessionCatalog.scala      |  5 ++++
 2 files changed, 12 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
index 2695fe4b54393..34743f5370057 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
@@ -39,7 +39,7 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
     val newPlan = if (isDDLCommand(plan)) {
       transformDDL(plan)
     } else {
-      transformPlan(plan, stringTypeForDMLCommand)
+      transformPlan(plan, sessionDefaultStringType)
     }
 
     if (!replaceWithTempType || newPlan.fastEquals(plan)) {
@@ -59,9 +59,7 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
    * default string type resolved.
    */
   def needsResolution(plan: LogicalPlan): Boolean = {
-    if (isDDLCommand(plan)) {
-      return true
-    } else if (isDefaultSessionCollationUsed) {
+    if (!isDDLCommand(plan) && isDefaultSessionCollationUsed) {
       return false
     }
 
@@ -93,13 +91,9 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
   private def stringTypeForDDLCommand(table: LogicalPlan): StringType =
     StringType("UTF8_BINARY")
 
-  /** Returns the default string type that should be used in DML commands. */
-  private def stringTypeForDMLCommand: StringType =
-    if (isDefaultSessionCollationUsed) {
-      StringType("UTF8_BINARY")
-    } else {
-      SQLConf.get.defaultStringType
-    }
+  /** Returns the session default string type. */
+  private def sessionDefaultStringType: StringType =
+    StringType(SQLConf.get.defaultStringType.collationId)
 
   private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists {
     case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true
@@ -176,21 +170,13 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
     dataType.transformRecursively {
       case currentType: StringType if isDefaultStringType(currentType) =>
         if (replaceWithTempType && currentType == newType) {
-          getTemporaryStringType(currentType)
+          TemporaryStringType()
        } else {
          newType
        }
    }
  }

-  private def getTemporaryStringType(forType: StringType): StringType = {
-    if (forType == StringType) {
-      TemporaryStringType(StringType("UTF8_LCASE").collationId)
-    } else {
-      TemporaryStringType(StringType.collationId)
-    }
-  }
-
   private def replaceColumnTypes(
       colTypes: Seq[QualifiedColType],
       newType: StringType): Seq[QualifiedColType] = {
@@ -210,7 +196,6 @@ case object ResolveDefaultStringTypes
 case object ResolveDefaultStringTypesWithoutTempType
   extends ResolveDefaultStringTypes(replaceWithTempType = false) {}
 
-case class TemporaryStringType(override val collationId: Int)
-  extends StringType(collationId) {
+case class TemporaryStringType() extends StringType(1) {
   override def toString: String = s"TemporaryStringType($collationId)"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 92c74f7bede18..5f1ab089cf3e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -54,6 +54,11 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
   import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case _ if ResolveDefaultStringTypes.needsResolution(plan) =>
+      // If the plan still contains unresolved default string types,
+      // we should not try to resolve the session catalog yet.
+      plan
+
     case AddColumns(ResolvedV1TableIdentifier(ident), cols) =>
       cols.foreach { c =>
         if (c.name.length > 1) {

From 45892fbda2018acb5b1df59bc7706a92c5c927b3 Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Thu, 28 Nov 2024 14:54:21 +0100
Subject: [PATCH 53/54] fix resolveOperators to be called only once

---
 .../analysis/ResolveDefaultStringTypes.scala  | 70 ++++++++-----------
 1 file changed, 31 insertions(+), 39 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
index 34743f5370057..039b0fa4a8f78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan}
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StringType}
 
 /**
@@ -33,24 +32,27 @@ import org.apache.spark.sql.types.{DataType, StringType}
  * [[TemporaryStringType]] object in cases where the old type and new are equal and thus would
  * not change the plan after transformation.
  */
-class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[LogicalPlan] {
+object ResolveDefaultStringTypes extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val newPlan = if (isDDLCommand(plan)) {
-      transformDDL(plan)
-    } else {
-      transformPlan(plan, sessionDefaultStringType)
-    }
-
-    if (!replaceWithTempType || newPlan.fastEquals(plan)) {
-      newPlan
-    } else {
+    val newPlan = apply0(plan)
+    if (plan.ne(newPlan)) {
       // Due to how tree transformations work and StringType object being equal to
-      // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice
+      // StringType("UTF8_BINARY"), we need to run the rule twice
       // to ensure the correct results for occurrences of default string type.
-      val finalPlan = ResolveDefaultStringTypesWithoutTempType.apply(newPlan)
+      val finalPlan = apply0(newPlan)
       RuleExecutor.forceAdditionalIteration(finalPlan)
       finalPlan
+    } else {
+      newPlan
+    }
+  }
+
+  private def apply0(plan: LogicalPlan): LogicalPlan = {
+    if (isDDLCommand(plan)) {
+      transformDDL(plan)
+    } else {
+      transformPlan(plan, sessionDefaultStringType)
     }
   }
 
@@ -79,10 +81,10 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
    * default string type resolved.
    */
   def needsResolution(expression: Expression): Boolean = {
-    expression.exists(e => transformExpression.isDefinedAt(e, StringType))
+    expression.exists(e => transformExpression.isDefinedAt(e))
   }
 
-  private def isDefaultSessionCollationUsed: Boolean = SQLConf.get.defaultStringType == StringType
+  private def isDefaultSessionCollationUsed: Boolean = conf.defaultStringType == StringType
 
   /**
    * Returns the default string type that should be used in a given DDL command (for now always
@@ -93,7 +95,7 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
 
   /** Returns the session default string type. */
   private def sessionDefaultStringType: StringType =
-    StringType(SQLConf.get.defaultStringType.collationId)
+    StringType(conf.defaultStringType.collationId)
 
   private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists {
     case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true
@@ -129,27 +131,25 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
    * new type instead of the default string type.
    */
   private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = {
-    plan resolveOperators { operator =>
-      operator resolveExpressionsUp { expression =>
-        transformExpression.applyOrElse((expression, newType), fallbackToExpression)
-      }
+    plan resolveExpressionsUp { expression =>
+      transformExpression
+        .andThen(_.apply(newType))
+        .applyOrElse(expression, identity[Expression])
     }
   }
 
-  private def fallbackToExpression(tuple: (Expression, StringType)): Expression = tuple._1
-
   /**
    * Transforms the given expression by changing all default string types to the given new type.
   */
-  private def transformExpression: PartialFunction[(Expression, StringType), Expression] = {
-    case (columnDef: ColumnDefinition, newType) if hasDefaultStringType(columnDef.dataType) =>
-      columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))
+  private def transformExpression: PartialFunction[Expression, StringType => Expression] = {
+    case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) =>
+      newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))
 
-    case (cast: Cast, newType) if hasDefaultStringType(cast.dataType) =>
-      cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))
+    case cast: Cast if hasDefaultStringType(cast.dataType) =>
+      newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))
 
-    case (Literal(value, dt), newType) if hasDefaultStringType(dt) =>
-      Literal(value, replaceDefaultStringType(dt, newType))
+    case Literal(value, dt) if hasDefaultStringType(dt) =>
+      newType => Literal(value, replaceDefaultStringType(dt, newType))
   }
 
   private def hasDefaultStringType(dataType: DataType): Boolean =
@@ -157,11 +157,9 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
 
   private def isDefaultStringType(dataType: DataType): Boolean = {
     dataType match {
-      case _: TemporaryStringType =>
-        !replaceWithTempType
       case st: StringType =>
         // should only return true for StringType object and not StringType("UTF8_BINARY")
-        st.eq(StringType)
+        st.eq(StringType) || st.isInstanceOf[TemporaryStringType]
      case _ => false
    }
  }
@@ -169,7 +167,7 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
   private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = {
     dataType.transformRecursively {
       case currentType: StringType if isDefaultStringType(currentType) =>
-        if (replaceWithTempType && currentType == newType) {
+        if (currentType == newType) {
           TemporaryStringType()
         } else {
           newType
@@ -190,12 +188,6 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
   }
 }
 
-case object ResolveDefaultStringTypes
-  extends ResolveDefaultStringTypes(replaceWithTempType = true) {}
-
-case object ResolveDefaultStringTypesWithoutTempType
-  extends ResolveDefaultStringTypes(replaceWithTempType = false) {}
-
 case class TemporaryStringType() extends StringType(1) {
   override def toString: String = s"TemporaryStringType($collationId)"
 }

From a9aa3d06df420d9d46d719f8fc4d79cd9ca78514 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 29 Nov 2024 10:44:05 +0800
Subject: [PATCH 54/54] Apply suggestions from code review

---
 .../catalyst/analysis/ResolveDefaultStringTypes.scala | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
index 039b0fa4a8f78..75958ff3e1177 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
@@ -23,22 +23,18 @@ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.types.{DataType, StringType}
 
 /**
- * Resolves default string types in DDL commands. For DML commands, the default string type is
+ * Resolves default string types in queries and commands. For queries, the default string type is
  * determined by the session's default string type. For DDL, the default string type is the
 * default type of the object (table -> schema -> catalog). However, this is not implemented yet.
  * So, we will just use UTF8_BINARY for now.
- *
- * `replaceWithTempType` is a flag that determines whether to replace the default string type with a
- * [[TemporaryStringType]] object in cases where the old type and new are equal and thus would
- * not change the plan after transformation.
  */
 object ResolveDefaultStringTypes extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
     val newPlan = apply0(plan)
     if (plan.ne(newPlan)) {
       // Due to how tree transformations work and StringType object being equal to
-      // StringType("UTF8_BINARY"), we need to run the rule twice
+      // StringType("UTF8_BINARY"), we need to transform the plan twice
       // to ensure the correct results for occurrences of default string type.
       val finalPlan = apply0(newPlan)
       RuleExecutor.forceAdditionalIteration(finalPlan)