[SPARK-49992][SQL] Default collation resolution for DDL and DML queries #48962

Closed
60 commits
25c057b
initial
stefankandic Oct 12, 2024
0c77855
remove check in CheckAnalysis.scala
stefankandic Oct 12, 2024
0921258
some working version
stefankandic Oct 13, 2024
b0a2139
add fix for eager eval of inline tables
stefankandic Oct 13, 2024
a18d5b9
initial working version with tests
stefankandic Oct 16, 2024
14130fa
change collation id for default type
stefankandic Oct 16, 2024
67e8dbb
fix typo
stefankandic Oct 16, 2024
172d58f
fix failing and add new tests
stefankandic Oct 17, 2024
d3b8f27
merge with master
stefankandic Oct 18, 2024
7ac5654
formatting
stefankandic Oct 18, 2024
82dcbf4
fix toString method
stefankandic Oct 18, 2024
014c855
fix duplicate test name
stefankandic Oct 19, 2024
fd86590
trigger ci
stefankandic Oct 21, 2024
99c9dd2
trigger ci
stefankandic Oct 22, 2024
32d1d1b
add more tests
stefankandic Oct 22, 2024
a50d52c
Merge branch 'master' into fixSessionCollation
stefankandic Oct 22, 2024
29f9a18
add support for create/alter view
stefankandic Oct 23, 2024
82ec0fa
Merge branch 'master' into fixSessionCollation
stefankandic Oct 23, 2024
8df3263
remove explicit collation in map access
stefankandic Oct 23, 2024
fd541a0
do not break parser for StringType
stefankandic Oct 29, 2024
e332022
fix compilation err
stefankandic Oct 29, 2024
7fafc99
fmt
stefankandic Oct 29, 2024
8821abc
Merge branch 'master' into fixSessionCollation
stefankandic Oct 29, 2024
175f703
remove StronglyTypedStringType
stefankandic Nov 13, 2024
d717fed
fix small bug in transform
stefankandic Nov 13, 2024
143d2c0
fix scalastyle
stefankandic Nov 13, 2024
0138f3a
add docstring
stefankandic Nov 13, 2024
27059a8
add check for not using session collation
stefankandic Nov 13, 2024
74414a1
inital
stefankandic Nov 14, 2024
ddfc137
add v1 and v2 tests
stefankandic Nov 14, 2024
e3ed8a0
merge with latest master
stefankandic Nov 14, 2024
d0ad673
add more comments
stefankandic Nov 14, 2024
467790b
fix failing tests
stefankandic Nov 19, 2024
1344580
add v1 api trait back
stefankandic Nov 20, 2024
658331a
fix tests that use schema as a parser
stefankandic Nov 20, 2024
dd823be
format methods
stefankandic Nov 20, 2024
ca7ec5a
formatting
stefankandic Nov 20, 2024
2f93f90
merge with master
stefankandic Nov 21, 2024
7d6dccc
move resolve to spark core
stefankandic Nov 21, 2024
0b0364f
use -1 id for temporary sstring type
stefankandic Nov 21, 2024
052ff47
move the rule back to the catalyst
stefankandic Nov 21, 2024
3e15226
Revert "move the rule back to the catalyst"
stefankandic Nov 25, 2024
86dc81b
initial
stefankandic Nov 25, 2024
c4dd663
add better check for inline tables
stefankandic Nov 25, 2024
95af604
init
stefankandic Nov 25, 2024
ba4defb
improvements
stefankandic Nov 25, 2024
999b335
golden files
stefankandic Nov 26, 2024
70abc52
fix regex failing test
stefankandic Nov 26, 2024
3e6e7cb
fix failing tests
stefankandic Nov 26, 2024
83de6e1
add test back
stefankandic Nov 26, 2024
2c8b567
add proper toString method for StringType
stefankandic Nov 26, 2024
10033a0
rename methods
stefankandic Nov 27, 2024
9b3035a
address pr comments
stefankandic Nov 27, 2024
59190ec
fix minor error in needsResolution logic
stefankandic Nov 27, 2024
311cdd4
fix formatting
stefankandic Nov 27, 2024
d92a5cf
bring back if/else
stefankandic Nov 27, 2024
26909ea
change for another iteration to be an integer
stefankandic Nov 27, 2024
203d7e5
address pr comments
stefankandic Nov 28, 2024
45892fb
fix resolve operators to be only called once
stefankandic Nov 28, 2024
a9aa3d0
Apply suggestions from code review
cloud-fan Nov 29, 2024
@@ -76,7 +76,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
case (TIMESTAMP_LTZ, Nil) => TimestampType
case (STRING, Nil) =>
typeCtx.children.asScala.toSeq match {
-case Seq(_) => SqlApiConf.get.defaultStringType
+case Seq(_) => StringType
case Seq(_, ctx: CollateClauseContext) =>
val collationName = visitCollateClause(ctx)
val collationId = CollationFactory.collationNameToId(collationName)
@@ -17,7 +17,6 @@

package org.apache.spark.sql.internal.types

-import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
@@ -26,7 +25,7 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
abstract class AbstractStringType(supportsTrimCollation: Boolean = false)
extends AbstractDataType
with Serializable {
-override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
+override private[sql] def defaultConcreteType: DataType = StringType
override private[sql] def simpleString: String = "string"

override private[sql] def acceptsType(other: DataType): Boolean = other match {
@@ -110,4 +110,13 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
f(this) || elementType.existsRecursively(f)
}

override private[spark] def transformRecursively(
f: PartialFunction[DataType, DataType]): DataType = {
if (f.isDefinedAt(this)) {
f(this)
} else {
ArrayType(elementType.transformRecursively(f), containsNull)
}
}
}
@@ -105,6 +105,13 @@ abstract class DataType extends AbstractDataType {
*/
private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)

/**
* Recursively applies the provided partial function `f` to transform this DataType tree.
*/
private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = {
if (f.isDefinedAt(this)) f(this) else this
}

final override private[sql] def defaultConcreteType: DataType = this

override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
12 changes: 12 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -89,6 +89,18 @@ case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Bo
override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
}

override private[spark] def transformRecursively(
f: PartialFunction[DataType, DataType]): DataType = {
if (f.isDefinedAt(this)) {
f(this)
} else {
MapType(
keyType.transformRecursively(f),
valueType.transformRecursively(f),
valueContainsNull)
}
}
}

/**
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory
* The id of collation for this StringType.
*/
@Stable
-class StringType private (val collationId: Int) extends AtomicType with Serializable {
+class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable {

/**
* Support for Binary Equality implies that strings are considered equal only if they are byte
@@ -75,7 +75,14 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
*/
override def typeName: String =
if (isUTF8BinaryCollation) "string"
-else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}"
+else s"string collate $collationName"
+
+override def toString: String =
+if (isUTF8BinaryCollation) "StringType"
+else s"StringType($collationName)"
+
+private[sql] def collationName: String =
+CollationFactory.fetchCollation(collationId).collationName

// Due to backwards compatibility and compatibility with other readers
// all string types are serialized in json as regular strings and
12 changes: 12 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -502,6 +502,18 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
f(this) || fields.exists(field => field.dataType.existsRecursively(f))
}

override private[spark] def transformRecursively(
f: PartialFunction[DataType, DataType]): DataType = {
if (f.isDefinedAt(this)) {
return f(this)
}

val newFields = fields.map { field =>
field.copy(dataType = field.dataType.transformRecursively(f))
}
StructType(newFields)
}
}

/**
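For readers who want to see the shape of the new `transformRecursively` hook in isolation, here is a minimal sketch on a toy type tree. The names `Ty`, `Str`, `Int32`, `Collated`, `Arr`, and `MapTy` are hypothetical stand-ins and not Spark classes. Only the pattern is taken from the hunks above: apply `f` when it is defined at the current node, otherwise rebuild the node from its recursively transformed children.

```scala
object TransformRecursivelySketch {
  // Toy type tree; leaf types use the default behaviour defined here.
  sealed trait Ty {
    // Apply `f` if it is defined at this node, otherwise keep the node as-is.
    def transformRecursively(f: PartialFunction[Ty, Ty]): Ty =
      if (f.isDefinedAt(this)) f(this) else this
  }

  case object Str extends Ty
  case object Int32 extends Ty
  final case class Collated(name: String) extends Ty

  // Container types recurse into their children when `f` does not match them.
  final case class Arr(element: Ty) extends Ty {
    override def transformRecursively(f: PartialFunction[Ty, Ty]): Ty =
      if (f.isDefinedAt(this)) f(this)
      else Arr(element.transformRecursively(f))
  }

  final case class MapTy(key: Ty, value: Ty) extends Ty {
    override def transformRecursively(f: PartialFunction[Ty, Ty]): Ty =
      if (f.isDefinedAt(this)) f(this)
      else MapTy(key.transformRecursively(f), value.transformRecursively(f))
  }

  def main(args: Array[String]): Unit = {
    val nested: Ty = MapTy(Str, Arr(Str))
    // Rewrite every plain string leaf, however deeply it is nested.
    val rewritten = nested.transformRecursively { case Str => Collated("UTF8_LCASE") }
    println(rewritten)
  }
}
```

Note that a match on a container short-circuits the recursion: if `f` is defined at an `Arr`, its element is not visited unless `f` does so itself. The hunks above appear to behave the same way.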
@@ -319,6 +319,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveAliases ::
ResolveSubquery ::
ResolveSubqueryColumnAliases ::
+ResolveDefaultStringTypes ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
import org.apache.spark.sql.util.SchemaUtils

/**
@@ -93,13 +93,6 @@ object CollationTypeCoercion {
val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad))
stringPadExpr.withNewChildren(Seq(newStr, len, newPad))

-case raiseError: RaiseError =>
-val newErrorParams = raiseError.errorParms.dataType match {
-case MapType(StringType, StringType, _) => raiseError.errorParms
-case _ => Cast(raiseError.errorParms, MapType(StringType, StringType))
-}
-raiseError.withNewChildren(Seq(raiseError.errorClass, newErrorParams))
-
case framelessOffsetWindow @ (_: Lag | _: Lead) =>
val Seq(input, offset, default) = framelessOffsetWindow.children
val Seq(newInput, newDefault) = collateToSingleType(Seq(input, default))
@@ -219,6 +212,11 @@ object CollationTypeCoercion {
*/
private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = {
if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
// if there are no collated types we don't need to do anything
return None
} else if (ResolveDefaultStringTypes.needsResolution(expressions)) {
// if any of the string types are still not resolved
// we need to wait for them to be resolved first
return None
}

@@ -0,0 +1,188 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.types.{DataType, StringType}

/**
* Resolves default string types in queries and commands. For queries, the default string type is
* determined by the session's default string type. For DDL, the default string type is the
* default type of the object (table -> schema -> catalog). However, this is not implemented yet.
* So, we will just use UTF8_BINARY for now.
*/
object ResolveDefaultStringTypes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
val newPlan = apply0(plan)
if (plan.ne(newPlan)) {
// Due to how tree transformations work and StringType object being equal to
// StringType("UTF8_BINARY"), we need to transform the plan twice
// to ensure the correct results for occurrences of default string type.
val finalPlan = apply0(newPlan)
RuleExecutor.forceAdditionalIteration(finalPlan)
finalPlan
} else {
newPlan
}
}

private def apply0(plan: LogicalPlan): LogicalPlan = {
if (isDDLCommand(plan)) {
transformDDL(plan)
} else {
transformPlan(plan, sessionDefaultStringType)
}
}

/**
* Returns whether any node of the given `plan` needs to have its
* default string type resolved.
*/
def needsResolution(plan: LogicalPlan): Boolean = {
if (!isDDLCommand(plan) && isDefaultSessionCollationUsed) {
return false
}

plan.exists(node => needsResolution(node.expressions))
}

/**
* Returns whether any of the given `expressions` needs to have its
* default string type resolved.
*/
def needsResolution(expressions: Seq[Expression]): Boolean = {
expressions.exists(needsResolution)
}

/**
* Returns whether the given `expression` needs to have its
* default string type resolved.
*/
def needsResolution(expression: Expression): Boolean = {
expression.exists(e => transformExpression.isDefinedAt(e))
}

private def isDefaultSessionCollationUsed: Boolean = conf.defaultStringType == StringType

/**
* Returns the default string type that should be used in a given DDL command (for now always
* UTF8_BINARY).
*/
private def stringTypeForDDLCommand(table: LogicalPlan): StringType =
StringType("UTF8_BINARY")

/** Returns the session default string type */
private def sessionDefaultStringType: StringType =
StringType(conf.defaultStringType.collationId)

private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists {
case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true
case _ => isCreateOrAlterPlan(plan)
}

private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match {
case _: V1CreateTablePlan | _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true
case _ => false
}

private def transformDDL(plan: LogicalPlan): LogicalPlan = {
val newType = stringTypeForDDLCommand(plan)

plan resolveOperators {
case p if isCreateOrAlterPlan(p) =>
transformPlan(p, newType)

case addCols: AddColumns =>
addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, newType))

case replaceCols: ReplaceColumns =>
replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType))

case alter: AlterColumn
if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) =>
alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType)))
}
}

/**
* Transforms the given plan, by transforming all expressions in its operators to use the given
* new type instead of the default string type.
*/
private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = {
plan resolveExpressionsUp { expression =>
transformExpression
.andThen(_.apply(newType))
.applyOrElse(expression, identity[Expression])
}
}

/**
* Transforms the given expression, by changing all default string types to the given new type.
*/
private def transformExpression: PartialFunction[Expression, StringType => Expression] = {
case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) =>
newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))

case cast: Cast if hasDefaultStringType(cast.dataType) =>
newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))

case Literal(value, dt) if hasDefaultStringType(dt) =>
newType => Literal(value, replaceDefaultStringType(dt, newType))
}

private def hasDefaultStringType(dataType: DataType): Boolean =
dataType.existsRecursively(isDefaultStringType)

private def isDefaultStringType(dataType: DataType): Boolean = {
dataType match {
case st: StringType =>
// should only return true for StringType object and not StringType("UTF8_BINARY")
st.eq(StringType) || st.isInstanceOf[TemporaryStringType]
case _ => false
}
}

private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = {
dataType.transformRecursively {
case currentType: StringType if isDefaultStringType(currentType) =>
if (currentType == newType) {
TemporaryStringType()
} else {
newType
}
}
}

private def replaceColumnTypes(
colTypes: Seq[QualifiedColType],
newType: StringType): Seq[QualifiedColType] = {
colTypes.map {
case colWithDefault if hasDefaultStringType(colWithDefault.dataType) =>
val replaced = replaceDefaultStringType(colWithDefault.dataType, newType)
colWithDefault.copy(dataType = replaced)

case col => col
}
}
}
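The `transformExpression` / `transformPlan` pair above relies on a `PartialFunction` that returns a function still waiting for the target type. The following is a small sketch of that idiom outside Spark. The `Expr`, `Lit`, and `Ref` types and the `"default"` type tag are hypothetical, chosen only for illustration. The `andThen` / `applyOrElse` calls are the standard Scala library methods used in the rule.

```scala
object PartialRewriteSketch {
  sealed trait Expr
  final case class Lit(value: String, tpe: String) extends Expr
  final case class Ref(name: String) extends Expr

  // Defined only for expressions that still carry the placeholder type.
  // The result is a function awaiting the concrete type to substitute.
  val rewrite: PartialFunction[Expr, String => Expr] = {
    case Lit(v, "default") => newType => Lit(v, newType)
  }

  // Supply the new type where `rewrite` matches; otherwise return `e` unchanged.
  def resolve(e: Expr, newType: String): Expr =
    rewrite.andThen(_.apply(newType)).applyOrElse(e, identity[Expr])

  // The same partial function doubles as the "does this need work?" check.
  def needsResolution(e: Expr): Boolean = rewrite.isDefinedAt(e)

  def main(args: Array[String]): Unit = {
    println(resolve(Lit("a", "default"), "UTF8_LCASE"))
    println(resolve(Ref("col"), "UTF8_LCASE"))
    println(needsResolution(Lit("a", "UTF8_LCASE")))
  }
}
```

Keeping the match and the rewrite in one partial function means the `needsResolution` check cannot drift out of sync with what the rule actually rewrites, which seems to be the motivation for `needsResolution(expression)` calling `transformExpression.isDefinedAt`.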

case class TemporaryStringType() extends StringType(1) {
override def toString: String = s"TemporaryStringType($collationId)"
}
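The comment in `apply` about transforming the plan twice is easier to follow with a toy model. The sketch below is an interpretation of the diff, not Spark code: `StrType`, its singleton companion, and `TempStrType` are hypothetical stand-ins for `StringType`, the `StringType` object, and `TemporaryStringType`. It assumes, as the comment states, that the singleton compares equal (`==`) to an explicitly constructed UTF8_BINARY type, so only reference equality (`eq`) can tell them apart.

```scala
object DefaultStringTypeSketch {
  // Toy stand-in for a string type whose equality is by collation id.
  class StrType(val collationId: Int) {
    override def equals(other: Any): Boolean = other match {
      case s: StrType => s.collationId == collationId
      case _ => false
    }
    override def hashCode(): Int = collationId
  }

  // Singleton "default" instance: equal to StrType(0), but a distinct reference.
  object StrType extends StrType(0) {
    def apply(id: Int): StrType = new StrType(id)
  }

  // Placeholder with a different id, so it never equals StrType(0).
  final class TempStrType extends StrType(1)

  // Only the singleton itself or the placeholder counts as "default".
  def isDefault(t: StrType): Boolean =
    t.eq(StrType) || t.isInstanceOf[TempStrType]

  // One rewrite pass: if the replacement would compare equal to the current
  // value, go through the placeholder first so the change is observable.
  def replaceDefault(t: StrType, newType: StrType): StrType =
    if (!isDefault(t)) t
    else if (t == newType) new TempStrType
    else newType

  def main(args: Array[String]): Unit = {
    val target = StrType(0)
    val pass1 = replaceDefault(StrType, target) // becomes the placeholder
    val pass2 = replaceDefault(pass1, target)   // becomes the explicit type
    println(isDefault(pass1))
    println(isDefault(pass2))
  }
}
```

When the target differs from the default (for example a non-zero collation id in this toy model), a single pass already yields the final type, so the second pass is a no-op.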
@@ -29,8 +29,12 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) {
-case table: UnresolvedInlineTable if table.expressionsResolved =>
+case table: UnresolvedInlineTable if canResolveTable(table) =>
EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table)
}
}

private def canResolveTable(table: UnresolvedInlineTable): Boolean = {
table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table)
}
}
@@ -415,7 +415,7 @@ abstract class TypeCoercionHelper {
if conf.concatBinaryAsString ||
!children.map(_.dataType).forall(_ == BinaryType) =>
val newChildren = c.children.map { e =>
-implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e)
+implicitCast(e, StringType).getOrElse(e)
}
c.copy(children = newChildren)
case other => other
@@ -465,7 +465,7 @@ abstract class TypeCoercionHelper {
if (conf.eltOutputAsString ||
!children.tail.map(_.dataType).forall(_ == BinaryType)) {
children.tail.map { e =>
-implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e)
+implicitCast(e, StringType).getOrElse(e)
}
} else {
children.tail
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeWithCollation
+import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType:
override def foldable: Boolean = false
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] =
-Seq(StringTypeWithCollation, MapType(StringType, StringType))
+Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation))

override def left: Expression = errorClass
override def right: Expression = errorParms