[SPARK-49358][SQL] Mode expression for map types with collated strings
### What changes were proposed in this pull request?
Introduce support for collated strings in map types for the `mode` expression.

### Why are the changes needed?
Complete complex type handling for the `mode` expression: struct and array types are already supported, and this adds the remaining map case.

### Does this PR introduce _any_ user-facing change?
Yes, the `mode` expression can now handle map types with collated strings.
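
For example (illustrative sketch only; the table and data below are made up, and the result follows the collation semantics exercised by the new tests):

```scala
// Illustrative sketch, not part of the patch: a map column with a collated
// string key type, aggregated with mode. Before this change, such a query was
// rejected with DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE.
spark.sql("CREATE TABLE t (m MAP<STRING COLLATE UTF8_LCASE, INT>) USING parquet")
spark.sql("INSERT INTO t VALUES (map('a', 1)), (map('A', 1)), (map('b', 1))")
// Under UTF8_LCASE the keys 'a' and 'A' compare equal, so the first two rows
// count as the same map and that map is returned as the mode.
spark.sql("SELECT mode(m) FROM t").show()
```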

### How was this patch tested?
New tests in `CollationSQLExpressionsSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48326 from uros-db/mode-map.

Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
uros-db authored and MaxGekk committed Oct 3, 2024
1 parent 036db74 commit 38f067d
Showing 3 changed files with 19 additions and 78 deletions.
common/utils/src/main/resources/error/error-conditions.json (0 additions, 5 deletions)
@@ -1016,11 +1016,6 @@
           "The input of <functionName> can't be <dataType> type data."
         ]
       },
-      "UNSUPPORTED_MODE_DATA_TYPE" : {
-        "message" : [
-          "The <mode> does not support the <child> data type, because there is a \"MAP\" type with keys and/or values that have collated sub-fields."
-        ]
-      },
       "UNSUPPORTED_UDF_INPUT_TYPE" : {
         "message" : [
           "UDFs do not support '<dataType>' as an input data type."
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala (15 additions, 23 deletions)
@@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, UnsafeRowUtils}
-import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, MapData, UnsafeRowUtils}
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -52,24 +52,6 @@ case class Mode(
 
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    // TODO: SPARK-49358: Mode expression for map type with collated fields
-    if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
-      !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
-        !UnsafeRowUtils.isBinaryStable(f))) {
-      /*
-       * The Mode class uses collation awareness logic to handle string data.
-       * All complex types except MapType with collated fields are supported.
-       */
-      super.checkInputDataTypes()
-    } else {
-      TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
-        messageParameters =
-          Map("child" -> toSQLType(child.dataType),
-            "mode" -> toSQLId(prettyName)))
-    }
-  }
-
   override def prettyName: String = "mode"
 
   override def update(
@@ -115,6 +97,7 @@ case class Mode(
       case st: StructType =>
         processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
       case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData])
+      case mt: MapType => processMapTypeWithBuffer(mt, data.asInstanceOf[MapData])
       case st: StringType =>
         CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId)
       case _ =>
@@ -140,6 +123,16 @@ case class Mode(
       collationAwareTransform(data.get(i, a.elementType), a.elementType))
   }
 
+  private def processMapTypeWithBuffer(mt: MapType, data: MapData): Map[Any, Any] = {
+    val transformedKeys = (0 until data.numElements()).map { i =>
+      collationAwareTransform(data.keyArray().get(i, mt.keyType), mt.keyType)
+    }
+    val transformedValues = (0 until data.numElements()).map { i =>
+      collationAwareTransform(data.valueArray().get(i, mt.valueType), mt.valueType)
+    }
+    transformedKeys.zip(transformedValues).toMap
+  }
+
   override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     if (buffer.isEmpty) {
       return null
@@ -157,8 +150,7 @@ case class Mode(
      *
      * The new map is then used in the rest of the Mode evaluation logic.
      *
-     * It is expected to work for all simple and complex types with
-     * collated fields, except for MapType (temporarily).
+     * It is expected to work for all simple and complex types with collated fields.
      */
     val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)
 
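
For intuition about the new `processMapTypeWithBuffer`, here is a self-contained sketch in plain Scala (no Spark internals; lowercasing stands in for `CollationFactory.getCollationKey` under a case-insensitive collation):

```scala
// Minimal sketch of the buffering idea, with lowercasing as a stand-in for
// the real collation key under UTF8_LCASE. Not Spark code.
object ModeMapSketch {
  // Stand-in for CollationFactory.getCollationKey on a collated string.
  def collationKey(s: String): String = s.toLowerCase

  // Analogue of processMapTypeWithBuffer: rewrite keys and values so that
  // maps that are equal under the collation become equal Scala maps.
  def normalize(m: Map[String, Int]): Map[String, Int] =
    m.map { case (k, v) => collationKey(k) -> v }

  def main(args: Array[String]): Unit = {
    val rows = Seq(Map("a" -> 1), Map("A" -> 1), Map("b" -> 1))
    // Count normalized values, as Mode's OpenHashMap buffer does.
    val buffer = rows.groupBy(normalize).map { case (k, vs) => k -> vs.size }
    println(buffer.maxBy(_._2)._1) // Map(a -> 1): 'a' and 'A' collapse together
  }
}
```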
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala (4 additions, 50 deletions)
@@ -19,11 +19,10 @@ package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.util.Locale
 
 import scala.collection.immutable.Seq
 
-import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable}
+import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
@@ -1924,30 +1923,14 @@ class CollationSQLExpressionsSuite
     }
   }
 
-  test("Support mode expression with collated in recursively nested struct with map with keys") {
+  test("Support mode for string expression with collated complex type - nested map") {
     case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String)
     Seq(
       ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"),
       ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"),
       ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"),
       ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}")
     ).foreach { t1 =>
-      def checkThisError(t: ModeTestCase, query: String): Any = {
-        val c = s"STRUCT<m1: MAP<STRING COLLATE ${t.collationId.toUpperCase(Locale.ROOT)}, INT>>"
-        val c1 = s"\"${c}\""
-        checkError(
-          exception = intercept[SparkThrowable] {
-            sql(query).collect()
-          },
-          condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE",
-          parameters = Map(
-            ("sqlExpr", "\"mode(i)\""),
-            ("child", c1),
-            ("mode", "`mode`")),
-          queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray
-        )
-      }
-
       def getValuesToAdd(t: ModeTestCase): String = {
         val valuesToAdd = t.bufferValues.map {
           case (elt, numRepeats) =>
@@ -1964,41 +1947,12 @@ class CollationSQLExpressionsSuite
         sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}")
         val query = "SELECT lower(cast(mode(i).m1 as string))" +
           s" FROM ${tableName}"
-        if (t1.collationId == "utf8_binary") {
-          checkAnswer(sql(query), Row(t1.result))
-        } else {
-          checkThisError(t1, query)
-        }
+        val queryResult = sql(query)
+        checkAnswer(queryResult, Row(t1.result))
       }
     }
   }
 
-  test("UDT with collation - Mode (throw exception)") {
-    case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String)
-    Seq(
-      ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
-      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
-      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
-    ).foreach { t1 =>
-      checkError(
-        exception = intercept[SparkIllegalArgumentException] {
-          Mode(
-            child = Literal.create(null,
-              MapType(StringType(t1.collationId), IntegerType))
-          ).collationAwareTransform(
-            data = Map.empty[String, Any],
-            dataType = MapType(StringType(t1.collationId), IntegerType)
-          )
-        },
-        condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
-        parameters = Map(
-          "expression" -> "\"mode(NULL)\"",
-          "functionName" -> "\"MODE\"",
-          "dataType" -> s"\"MAP<STRING COLLATE ${t1.collationId.toUpperCase()}, INT>\"")
-      )
-    }
-  }
-
   test("SPARK-48430: Map value extraction with collations") {
     for {
       collateKey <- Seq(true, false)
