[VL] Support map_concat spark function #5093

Closed · wants to merge 2 commits
@@ -179,6 +179,7 @@ object CHExpressionUtil {
     SPARK_PARTITION_ID -> DefaultValidator(),
     URL_DECODE -> DefaultValidator(),
     SKEWNESS -> DefaultValidator(),
-    BIT_LENGTH -> DefaultValidator()
+    BIT_LENGTH -> DefaultValidator(),
+    MAP_CONCAT -> DefaultValidator()
   )
 }
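
For context, a DefaultValidator entry in CHExpressionUtil marks a function as unsupported on the ClickHouse backend, so map_concat fails validation there and stays on vanilla Spark. A minimal sketch of how that fallback could be observed (the session setup and plan check are illustrative assumptions, not part of this diff):

// Assumes a Gluten-enabled SparkSession `spark` running the ClickHouse backend.
val plan = spark.sql("SELECT map_concat(map(1, 2), map(3, 4))")
  .queryExecution.executedPlan
// Expect a vanilla ProjectExec in the plan, not a Gluten ProjectExecTransformer.
assert(!plan.toString.contains("ProjectExecTransformer"))
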
@@ -581,4 +581,14 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
     }
   }
 
+  test("Test map_concat function") {
+    runQueryAndCompare("select map_concat(map(4, 6), map(7, 8))") {
+      checkOperatorMatch[ProjectExecTransformer]
+    }
+    runQueryAndCompare(
+      "select map_concat(map(l_returnflag, l_comment), map(l_shipmode, l_comment)) " +
+        "from lineitem limit 1") {
+      checkOperatorMatch[ProjectExecTransformer]
+    }
+  }
 }
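
Here runQueryAndCompare executes the query through Gluten and checks the result against vanilla Spark, while checkOperatorMatch[ProjectExecTransformer] asserts the projection was actually offloaded. A hypothetical extra case following the same pattern (not part of this diff) would look like:

runQueryAndCompare("select map_concat(map('a', 1), map('b', 2))") {
  checkOperatorMatch[ProjectExecTransformer]
}
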
6 changes: 3 additions & 3 deletions docs/velox-backend-support-progress.md
@@ -276,7 +276,7 @@ Gluten supports 199 functions. (Draw to right to see all data types)
 | filter | filter | filter | | | | | | | | | | | | | | | | | | | | |
 | flatten | flatten | | | | | | | | | | | | | | | | | | | | | |
 | map | map | map | S | | | | | | | | | | | | | | | | | | | |
-| map_concat | map_concat | | | | | | | | | | | | | | | | | | | | | |
+| map_concat | map_concat | map_concat | S | | | | | | | | | | | | | | | | | | | |
 | map_entries | map_entries | | | | | | | | | | | | | | | | | | | | | |
 | map_filter | map_filter | map_filter | | | | | | | | | | | | | | | | | | | | |
 | get_map_value | | element_at | S | | | | | | | | | | | | | | | | | S | | |
@@ -416,7 +416,7 @@ Gluten supports 199 functions. (Draw to right to see all data types)
 | java_method | | | | | | | | | | | | | | | | | | | | | | |
 | least | least | least | S | | | | | | S | S | S | S | S | | | | | | | | | |
 | md5 | md5 | | S | | | S | | | | | | | | | | | | | | | | |
-| monotonically_increasing_id | | | S | | | | | | | | | | | | | | | | | | | |
+| monotonically_increasing_id | | | S | | | | | | | | | | | | | | | | | | | |
 | nanvl | | | S | | | | | | | | | | | | | | | | | | | |
 | nvl | | | | | | | | | | | | | | | | | | | | | | |
 | nvl2 | | | | | | | | | | | | | | | | | | | | | | |
@@ -428,4 +428,4 @@ Gluten supports 199 functions. (Draw to right to see all data types)
 | spark_partition_id | | | S | | | | | | | | | | | | | | | | | | | |
 | stack | | | | | | | | | | | | | | | | | | | | | | |
 | xxhash64 | xxhash64 | xxhash64 | | | | | | | | | | | | | | | | | | | | |
-| uuid | uuid | uuid | S | | | | | | | | | | | | | | | | | | | |
+| uuid | uuid | uuid | S | | | | | | | | | | | | | | | | | | | |
@@ -226,6 +226,7 @@ object ExpressionMappings {
     Sig[MapFromArrays](MAP_FROM_ARRAYS),
     Sig[MapEntries](MAP_ENTRIES),
     Sig[StringToMap](STR_TO_MAP),
+    Sig[MapConcat](MAP_CONCAT),
     // Struct functions
     Sig[GetStructField](GET_STRUCT_FIELD),
     Sig[CreateNamedStruct](NAMED_STRUCT),
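
The Sig[MapConcat](MAP_CONCAT) entry ties Spark's MapConcat expression class to the name defined in ExpressionNames (see the last hunk below), which is what lets Gluten's planner translate the expression into its Substrait/Velox counterpart.
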
@@ -265,6 +265,8 @@ class VeloxTestSettings extends BackendTestSettings {
     // blocked by Velox-5768
     .exclude("aggregate function - array for primitive type containing null")
     .exclude("aggregate function - array for non-primitive type")
+    // Velox by default removes duplicates from map_concat and does not throw an exception
+    .exclude("map_concat function")
   enableSuite[GlutenDataFrameTungstenSuite]
   enableSuite[GlutenDataFrameSetOperationsSuite]
     // Result depends on the implementation for the nondeterministic expression rand.
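
The excluded vanilla test asserts Spark's default duplicate-key behavior, which Velox does not implement. A minimal sketch of the gap (assumes a Gluten+Velox-enabled SparkSession `spark`):

// Vanilla Spark defaults to spark.sql.mapKeyDedupPolicy=EXCEPTION and raises
// a DUPLICATE_MAP_KEY error here; Velox keeps the last value instead.
spark.sql("SELECT map_concat(map(1, 10), map(1, 20))").collect()
// Gluten+Velox result: Array([Map(1 -> 20)])
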
@@ -937,6 +937,8 @@ class VeloxTestSettings extends BackendTestSettings {
     // blocked by Velox-5768
     .exclude("aggregate function - array for primitive type containing null")
     .exclude("aggregate function - array for non-primitive type")
+    // Velox by default removes duplicates from map_concat and does not throw an exception
+    .exclude("map_concat function")
   enableSuite[GlutenDataFrameHintSuite]
   enableSuite[GlutenDataFrameImplicitsSuite]
   enableSuite[GlutenDataFrameJoinSuite]
@@ -941,6 +941,8 @@ class VeloxTestSettings extends BackendTestSettings {
     // blocked by Velox-5768
     .exclude("aggregate function - array for primitive type containing null")
     .exclude("aggregate function - array for non-primitive type")
+    // Velox by default removes duplicates from map_concat and does not throw an exception
+    .exclude("map_concat function")
   enableSuite[GlutenDataFrameHintSuite]
   enableSuite[GlutenDataFrameImplicitsSuite]
   enableSuite[GlutenDataFrameJoinSuite]
@@ -16,4 +16,121 @@
  */
 package org.apache.spark.sql
 
-class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {}
+import org.apache.spark.sql.functions.{lit, map_concat}
+import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructField, StructType}
+
+class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
+
+  testGluten("map_concat function") {
+    import testImplicits._
+    val df1 = Seq(
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)),
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)),
+      (null, Map[Int, Int](3 -> 300, 4 -> 400))
+    ).toDF("map1", "map2")
+
+    val expected1a = Seq(
+      Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)),
+      Row(Map(1 -> 400, 2 -> 200, 3 -> 300)),
+      Row(null)
+    )
+
+    // Velox by default deduplicates map keys; the behavior is the same as SQLConf.MapKeyDedupPolicy.LAST_WIN
+    checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+    checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
+
+    // Velox requires at least 2 arguments for map_concat
+    intercept[Exception](df1.selectExpr("map_concat(map1)").collect())
+    intercept[Exception](df1.select(map_concat($"map1")).collect())
+
+    val df2 = Seq(
+      (
+        Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200),
+        Map[String, Int]("3" -> 300, "4" -> 400)
+      )
+    ).toDF("map1", "map2")
+
+    val expected2 = Seq(Row(Map()))
+
+    // map_concat with 0 arguments falls back to Spark during validation
+    checkAnswer(df2.selectExpr("map_concat()"), expected2)
+    checkAnswer(df2.select(map_concat()), expected2)
+
+    val df3 = {
+      val schema = StructType(
+        StructField("map1", MapType(StringType, IntegerType, true), false) ::
+          StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil
+      )
+      val data = Seq(
+        Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)),
+        Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4))
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+
+    val expected3 = Seq(
+      Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)),
+      Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4))
+    )
+
+    checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3)
+    checkAnswer(df3.select(map_concat($"map1", $"map2")), expected3)
+
+    // Data type mismatches are handled by Spark in the analysis phase
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("map_concat(map1, map2)").collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, map2)\"",
+        "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+        "functionName" -> "`map_concat`"),
+      context = ExpectedContext(
+        fragment = "map_concat(map1, map2)",
+        start = 0,
+        stop = 21)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.select(map_concat($"map1", $"map2")).collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, map2)\"",
+        "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+        "functionName" -> "`map_concat`")
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("map_concat(map1, 12)").collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, 12)\"",
+        "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+        "functionName" -> "`map_concat`"),
+      context = ExpectedContext(
+        fragment = "map_concat(map1, 12)",
+        start = 0,
+        stop = 19)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.select(map_concat($"map1", lit(12))).collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, 12)\"",
+        "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+        "functionName" -> "`map_concat`")
+    )
+  }
+}
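
Two details worth noting in the expected data above: map_concat returns null when any input map is null (the third row of df1), and the deduplicated result Map(1 -> 400) is exactly what vanilla Spark produces when spark.sql.mapKeyDedupPolicy is set to LAST_WIN, so the Gluten results diverge from Spark only under the default EXCEPTION policy.
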
@@ -237,6 +237,7 @@ object ExpressionNames {
   final val MAP_FROM_ARRAYS = "map_from_arrays"
   final val MAP_ENTRIES = "map_entries"
   final val STR_TO_MAP = "str_to_map"
+  final val MAP_CONCAT = "map_concat"
 
   // struct functions
   final val GET_STRUCT_FIELD = "get_struct_field"