diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
index dd9fbed1c690..9aa01076c1d7 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.execution.{ScalarSubquery, _}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
-import org.apache.spark.sql.hive.HiveSimpleUDFTransformer
+import org.apache.spark.sql.hive.HiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -109,8 +109,8 @@ object ExpressionConverter extends SQLConfHelper with Logging {
         return replacePythonUDFWithExpressionTransformer(p, attributeSeq, expressionsMap)
       case s: ScalaUDF =>
         return replaceScalaUDFWithExpressionTransformer(s, attributeSeq, expressionsMap)
-      case _ if HiveSimpleUDFTransformer.isHiveSimpleUDF(expr) =>
-        return HiveSimpleUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+      case _ if HiveUDFTransformer.isHiveUDF(expr) =>
+        return HiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
       case _ =>
     }
 
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveSimpleUDFTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
similarity index 65%
rename from gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveSimpleUDFTransformer.scala
rename to gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
index 1672648b308c..3ea448b48d3a 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveSimpleUDFTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
@@ -22,10 +22,10 @@ import org.apache.spark.sql.catalyst.expressions._
 
 import java.util.Locale
 
-object HiveSimpleUDFTransformer {
-  def isHiveSimpleUDF(expr: Expression): Boolean = {
+object HiveUDFTransformer {
+  def isHiveUDF(expr: Expression): Boolean = {
     expr match {
-      case _: HiveSimpleUDF => true
+      case _: HiveSimpleUDF | _: HiveGenericUDF => true
       case _ => false
     }
   }
@@ -33,23 +33,26 @@ object HiveSimpleUDFTransformer {
   def replaceWithExpressionTransformer(
       expr: Expression,
       attributeSeq: Seq[Attribute]): ExpressionTransformer = {
-    if (!isHiveSimpleUDF(expr)) {
-      throw new UnsupportedOperationException(s"Expression $expr is not a HiveSimpleUDF")
+    val udfName = expr match {
+      case s: HiveSimpleUDF =>
+        s.name.stripPrefix("default.")
+      case g: HiveGenericUDF =>
+        g.name.stripPrefix("default.")
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
     }
 
-    val udf = expr.asInstanceOf[HiveSimpleUDF]
-    val substraitExprName =
-      UDFMappings.hiveUDFMap.get(udf.name.stripPrefix("default.").toLowerCase(Locale.ROOT))
-    substraitExprName match {
+    UDFMappings.hiveUDFMap.get(udfName.toLowerCase(Locale.ROOT)) match {
       case Some(name) =>
         GenericExpressionTransformer(
           name,
-          ExpressionConverter.replaceWithExpressionTransformer(udf.children, attributeSeq),
-          udf)
+          ExpressionConverter.replaceWithExpressionTransformer(expr.children, attributeSeq),
+          expr)
       case _ =>
         throw new UnsupportedOperationException(
-          s"Not supported hive simple udf:$udf"
-            + s" name:${udf.name} hiveUDFMap:${UDFMappings.hiveUDFMap}")
+          s"Not supported hive udf:$expr"
+            + s" name:$udfName hiveUDFMap:${UDFMappings.hiveUDFMap}")
     }
   }
 }