diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/GlutenPlanManager.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/GlutenPlanManager.scala
index d7277f084b7..174fecf4b5f 100644
--- a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/GlutenPlanManager.scala
+++ b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/GlutenPlanManager.scala
@@ -19,7 +19,7 @@ package org.apache.kyuubi.sql
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.UserDefinedExpression
+import org.apache.spark.sql.catalyst.expressions.{ArrayContains, ArrayIntersect, ArraySort, Bin, Contains, EndsWith, LastDay, MakeDate, Overlay, Rand, Randn, Size, SortArray, StartsWith, ToUnixTimestamp, UnaryMinus, UnixTimestamp, UserDefinedExpression}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.{ObjectHashAggregateExec, SortAggregateExec}
@@ -56,7 +56,6 @@ object GlutenPlanAnalysis extends Rule[SparkPlan] {
           if !p.relation.fileFormat.isInstanceOf[ParquetFileFormat] =>
         true
       case _: RowDataSourceScanExec |
-          _: UserDefinedExpression |
           _: CartesianProductExec |
           _: ShuffleExchangeExec |
           _: ObjectHashAggregateExec |
@@ -67,6 +66,30 @@ object GlutenPlanAnalysis extends Rule[SparkPlan] {
           _: SampleExec |
           _: BroadcastNestedLoopJoinExec =>
         true
+      case p: SparkPlan
+          if p.expressions.exists(e =>
+            e.exists {
+              case _: UserDefinedExpression |
+                  _: UnaryMinus |
+                  _: Bin |
+                  _: Contains |
+                  _: StartsWith |
+                  _: EndsWith |
+                  _: Overlay |
+                  _: Rand |
+                  _: Randn |
+                  _: ArrayContains |
+                  _: ArrayIntersect |
+                  _: ArraySort |
+                  _: SortArray |
+                  _: Size |
+                  _: LastDay |
+                  _: MakeDate |
+                  _: ToUnixTimestamp |
+                  _: UnixTimestamp => true
+              case _ => false
+            }) =>
+        true
       // TODO check whether the plan contains unsupported expressions
     }.size
     check(count)
diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/kyuubi/sql/GlutenPlanManagerSuite.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/kyuubi/sql/GlutenPlanManagerSuite.scala
index 60039cce77c..dc18290ebea 100644
--- a/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/kyuubi/sql/GlutenPlanManagerSuite.scala
+++ b/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/kyuubi/sql/GlutenPlanManagerSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.KyuubiSparkSQLExtensionTest
 
 class GlutenPlanManagerSuite extends KyuubiSparkSQLExtensionTest {
 
-  test("Kyuubi Extension fast fail if over un-supported operator threshold") {
+  test("Kyuubi Extension fast fail with Plan if over un-supported operator threshold") {
     withSQLConf(
       KyuubiSQLConf.GLUTEN_FALLBACK_OPERATOR_THRESHOLD.key -> "1") {
       withTable("gluten_tmp_1") {
@@ -36,4 +36,17 @@ class GlutenPlanManagerSuite extends KyuubiSparkSQLExtensionTest {
       }
     }
   }
+
+  test("Kyuubi Extension fast fail with Expression if over un-supported operator threshold") {
+    withSQLConf(KyuubiSQLConf.GLUTEN_FALLBACK_OPERATOR_THRESHOLD.key -> "1") {
+      assertThrows[TooMuchGlutenUnsupportedOperationException] {
+        sql("SELECT rand(100)").collect()
+      }
+
+      spark.udf.register("str_len", (s: String) => s.length)
+      assertThrows[TooMuchGlutenUnsupportedOperationException] {
+        sql("SELECT str_len('123')").collect()
+      }
+    }
+  }
 }