From 05f40b5a2904a38045b82b387cde23af7802a90c Mon Sep 17 00:00:00 2001
From: "Robert (Bobby) Evans"
Date: Thu, 24 Oct 2024 08:31:56 -0500
Subject: [PATCH] Put DF_UDF plugin code into the main uber jar. (#11634)

Signed-off-by: Robert (Bobby) Evans
---
 df_udf/README.md => DF_UDF_README.md          |  35 ++-
 df_udf/pom.xml                                |  88 --------
 .../nvidia/spark/rapids}/functionsSuite.scala |   2 +-
 .../spark/sql/nvidia/SparkTestBase.scala      |   0
 pom.xml                                       |   1 -
 scala2.13/df_udf/pom.xml                      |  88 --------
 scala2.13/pom.xml                             |   1 -
 .../scala/com/nvidia/spark/DFUDFPlugin.scala  |  14 +-
 .../scala/com/nvidia/spark/functions.scala    |  83 +++----
 .../nvidia/spark/rapids/DFUDFPluginAPI.scala  |  27 +++
 .../com/nvidia/spark/rapids/Functions.scala   | 193 ++++++++++++++++
 .../com/nvidia/spark/rapids/ShimLoader.scala  |   8 +
 .../com/nvidia/spark/DFUDFPluginImpl.scala    |  33 +++
 .../com/nvidia/spark/FunctionsImpl.scala      | 212 ++++++++++++++++++
 .../com/nvidia/spark/rapids/Plugin.scala      |   7 +-
 .../com/nvidia/spark/rapids/RapidsConf.scala  |   8 +
 .../spark/sql/nvidia/LogicalPlanRules.scala   |  11 +-
 .../spark/sql/nvidia/dataframe_udfs.scala     |   0
 .../apache/spark/sql/nvidia/DFUDFShims.scala  |   0
 .../apache/spark/sql/nvidia/DFUDFShims.scala  |   0
 20 files changed, 562 insertions(+), 249 deletions(-)
 rename df_udf/README.md => DF_UDF_README.md (78%)
 delete mode 100644 df_udf/pom.xml
 rename {df_udf/src/test/scala/com/nvidia/spark => integration_tests/src/test/scala/com/nvidia/spark/rapids}/functionsSuite.scala (99%)
 rename {df_udf => integration_tests}/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala (100%)
 delete mode 100644 scala2.13/df_udf/pom.xml
 rename {df_udf => sql-plugin-api}/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala (83%)
 rename {df_udf => sql-plugin-api}/src/main/scala/com/nvidia/spark/functions.scala (83%)
 create mode 100644 sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/DFUDFPluginAPI.scala
 create mode 100644 sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/Functions.scala
 create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/DFUDFPluginImpl.scala
 create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/FunctionsImpl.scala
 rename {df_udf => sql-plugin}/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala (85%)
 rename {df_udf => sql-plugin}/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala (100%)
 rename {df_udf => sql-plugin}/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala (100%)
 rename {df_udf => sql-plugin}/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala (100%)

diff --git a/df_udf/README.md b/DF_UDF_README.md
similarity index 78%
rename from df_udf/README.md
rename to DF_UDF_README.md
index 0226c365a42..a669c87f258 100644
--- a/df_udf/README.md
+++ b/DF_UDF_README.md
@@ -15,9 +15,14 @@ commands.
 
 ## Setup
 
-To do this include com.nvidia:df_udf_plugin as a dependency for your project and also include it on the
-classpath for your Apache Spark environment. Then include `com.nvidia.spark.DFUDFPlugin` in the config
-`spark.sql.extensions`. Now you can implement a UDF in terms of Dataframe operations.
+The dataframe UDF plugin is packaged in the same jar as the RAPIDS Accelerator for Apache Spark. Add this jar as a
+compile-time dependency for any code that uses this feature, and also add it to your Spark classpath, just as you
+would for GPU acceleration.
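+
+For example, an application could be configured roughly as follows. This is a minimal sketch: the jar path and
+version are placeholders, and setting `spark.sql.extensions` here is only required for the CPU-only case described
+below.
+
+```scala
+import org.apache.spark.sql.SparkSession
+
+// Build a session with the RAPIDS Accelerator jar on the classpath and the
+// dataframe UDF extension enabled. Adjust the jar path for your environment.
+val spark = SparkSession.builder()
+  .appName("df-udf-example")
+  .config("spark.jars", "/path/to/rapids-4-spark_2.12-24.12.0.jar")
+  .config("spark.sql.extensions", "com.nvidia.spark.DFUDFPlugin")
+  .getOrCreate()
+```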
+
+If you do not plan to use GPU accelerated processing, but still want dataframe UDF support for CPU-only
+applications, add `com.nvidia.spark.DFUDFPlugin` to the `spark.sql.extensions` config. If you do use GPU
+accelerated processing, the RAPIDS Plugin enables this feature automatically, so you do not need to set
+`spark.sql.extensions`, although setting it does no harm. Now you can implement a UDF in terms of Dataframe operations.
 
 ## Usage
@@ -48,6 +53,28 @@ Seq(Array(1L, 2L, 3L)).toDF("data").selectExpr("sum_array(data) as result").show
 +------+
 ```
+Java APIs are also supported and work the same way as Spark's Java UDF APIs.
+
+```java
+import static com.nvidia.spark.functions.df_udf;
+
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.api.java.UDF2;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
+
+
+UserDefinedFunction myAdd = df_udf((UDF2<Column, Column, Column>) (lhs, rhs) -> lhs.plus(rhs));
+spark.udf().register("myadd", myAdd);
+
+spark.sql("SELECT myadd(1, 1) as r").show();
+// +--+
+// | r|
+// +--+
+// | 2|
+// +--+
+
+```
+
 ## Type Checks
 
 DataFrame APIs do not provide type safety when writing the code and that is the same here. There are no builtin type
@@ -87,4 +114,4 @@ at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$
 at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$57$$anonfun$applyOrElse$234.applyOrElse(Analyzer.scala:3654)
 ```
-We hope to add optional type checks in the future.
\ No newline at end of file
+We hope to add optional type checks in the future.
diff --git a/df_udf/pom.xml b/df_udf/pom.xml
deleted file mode 100644
index 39f33880f34..00000000000
--- a/df_udf/pom.xml
+++ /dev/null
@@ -1,88 +0,0 @@
-
-
-
- 4.0.0
-
- com.nvidia
- rapids-4-spark-shim-deps-parent_2.12
- 24.12.0-SNAPSHOT
- ../shim-deps/pom.xml
-
- df_udf_plugin_2.12
- UDFs implemented in SQL/Dataframe
- UDFs for Apache Spark implemented in SQL/Dataframe
- 24.12.0-SNAPSHOT
-
-
- df_udf
-
- **/*
- package
- ${project.build.outputDirectory}/df_udf-version-info.properties
-
-
-
-
- org.scala-lang
- scala-library
-
-
- org.scalatest
- scalatest_${scala.binary.version}
- test
-
-
- org.apache.spark
- spark-sql_${scala.binary.version}
- ${spark.test.version}
-
-
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
- true
-
-
-
- net.alchim31.maven
- scala-maven-plugin
-
-
- org.scalatest
- scalatest-maven-plugin
-
-
- org.apache.rat
- apache-rat-plugin
-
-
-
-
-
-
- ${project.build.directory}/extra-resources
-
-
-
diff --git a/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/functionsSuite.scala
similarity index 99%
rename from df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala
rename to integration_tests/src/test/scala/com/nvidia/spark/rapids/functionsSuite.scala
index ae6d46aefdf..04e2f8a48f1 100644
--- a/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/functionsSuite.scala
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ -package com.nvidia.spark +package com.nvidia.spark.rapids import com.nvidia.spark.functions._ diff --git a/df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala b/integration_tests/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala similarity index 100% rename from df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala rename to integration_tests/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala diff --git a/pom.xml b/pom.xml index 13caf83013c..893af998540 100644 --- a/pom.xml +++ b/pom.xml @@ -73,7 +73,6 @@ aggregator datagen - df_udf dist integration_tests shuffle-plugin diff --git a/scala2.13/df_udf/pom.xml b/scala2.13/df_udf/pom.xml deleted file mode 100644 index 04f7a6deb28..00000000000 --- a/scala2.13/df_udf/pom.xml +++ /dev/null @@ -1,88 +0,0 @@ - - - - 4.0.0 - - com.nvidia - rapids-4-spark-shim-deps-parent_2.13 - 24.12.0-SNAPSHOT - ../shim-deps/pom.xml - - df_udf_plugin_2.13 - UDFs implemented in SQL/Dataframe - UDFs for Apache Spark implemented in SQL/Dataframe - 24.12.0-SNAPSHOT - - - df_udf - - **/* - package - ${project.build.outputDirectory}/df_udf-version-info.properties - - - - - org.scala-lang - scala-library - - - org.scalatest - scalatest_${scala.binary.version} - test - - - org.apache.spark - spark-sql_${scala.binary.version} - ${spark.test.version} - - - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - true - - - - net.alchim31.maven - scala-maven-plugin - - - org.scalatest - scalatest-maven-plugin - - - org.apache.rat - apache-rat-plugin - - - - - - - ${project.build.directory}/extra-resources - - - - diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index cca074773d2..02932dcc715 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -73,7 +73,6 @@ aggregator datagen - df_udf dist integration_tests shuffle-plugin diff --git a/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala b/sql-plugin-api/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala similarity index 83% rename from df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala rename to sql-plugin-api/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala index 7e1c0451c8a..3322dd3fd11 100644 --- a/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala +++ b/sql-plugin-api/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala @@ -16,16 +16,18 @@ package com.nvidia.spark +import com.nvidia.spark.rapids.{DFUDFPluginAPI, ShimLoader} + import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule class DFUDFPlugin extends (SparkSessionExtensions => Unit) { - override def apply(extensions: SparkSessionExtensions): Unit = { - extensions.injectResolutionRule(logicalPlanRules) - } + private lazy val impl: DFUDFPluginAPI = ShimLoader.newDFUDFImpl() + + override def apply(extensions: SparkSessionExtensions): Unit = + impl(extensions) - def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] = { - org.apache.spark.sql.nvidia.LogicalPlanRules() - } + def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] = + impl.logicalPlanRules(sparkSession) } \ No newline at end of file diff --git a/df_udf/src/main/scala/com/nvidia/spark/functions.scala b/sql-plugin-api/src/main/scala/com/nvidia/spark/functions.scala similarity index 83% rename from df_udf/src/main/scala/com/nvidia/spark/functions.scala rename to sql-plugin-api/src/main/scala/com/nvidia/spark/functions.scala index 8c8eef3f825..d3bce5af637 100644 --- 
a/df_udf/src/main/scala/com/nvidia/spark/functions.scala +++ b/sql-plugin-api/src/main/scala/com/nvidia/spark/functions.scala @@ -16,40 +16,38 @@ package com.nvidia.spark +import com.nvidia.spark.rapids.{Functions, ShimLoader} + import org.apache.spark.sql.Column import org.apache.spark.sql.api.java.{UDF0, UDF1, UDF10, UDF2, UDF3, UDF4, UDF5, UDF6, UDF7, UDF8, UDF9} import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.nvidia._ -import org.apache.spark.sql.types.LongType // scalastyle:off object functions { // scalastyle:on + private lazy val impl: Functions = ShimLoader.newFunctionsImpl() + /** * Defines a Scala closure of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: Function0[Column]): UserDefinedFunction = - udf(DFUDF0(f), LongType) + def df_udf(f: Function0[Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: Function1[Column, Column]): UserDefinedFunction = - udf(DFUDF1(f), LongType) + def df_udf(f: Function1[Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: Function2[Column, Column, Column]): UserDefinedFunction = - udf(DFUDF2(f), LongType) + def df_udf(f: Function2[Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -57,7 +55,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: Function3[Column, Column, Column, Column]): UserDefinedFunction = - udf(DFUDF3(f), LongType) + impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -65,7 +63,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: Function4[Column, Column, Column, Column, Column]): UserDefinedFunction = - udf(DFUDF4(f), LongType) + impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -73,7 +71,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: Function5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction = - udf(DFUDF5(f), LongType) + impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -81,8 +79,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: Function6[Column, Column, Column, Column, Column, Column, - Column]): UserDefinedFunction = - udf(DFUDF6(f), LongType) + Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -90,8 +87,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
*/ def df_udf(f: Function7[Column, Column, Column, Column, Column, Column, - Column, Column]): UserDefinedFunction = - udf(DFUDF7(f), LongType) + Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -99,8 +95,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: Function8[Column, Column, Column, Column, Column, Column, - Column, Column, Column]): UserDefinedFunction = - udf(DFUDF8(f), LongType) + Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -108,8 +103,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: Function9[Column, Column, Column, Column, Column, Column, - Column, Column, Column, Column]): UserDefinedFunction = - udf(DFUDF9(f), LongType) + Column, Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Scala closure of Columns as user-defined function (UDF). @@ -117,8 +111,7 @@ object functions { * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: Function10[Column, Column, Column, Column, Column, Column, - Column, Column, Column, Column, Column]): UserDefinedFunction = - udf(DFUDF10(f), LongType) + Column, Column, Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) ////////////////////////////////////////////////////////////////////////////////////////////// @@ -130,54 +123,43 @@ object functions { * By default the returned UDF is deterministic. To change it to nondeterministic, call the * API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: UDF0[Column]): UserDefinedFunction = { - udf(JDFUDF0(f), LongType) - } + def df_udf(f: UDF0[Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to nondeterministic, call the * API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: UDF1[Column, Column]): UserDefinedFunction = { - udf(JDFUDF1(f), LongType) - } + def df_udf(f: UDF1[Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to nondeterministic, call the * API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: UDF2[Column, Column, Column]): UserDefinedFunction = { - udf(JDFUDF2(f), LongType) - } + def df_udf(f: UDF2[Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to nondeterministic, call the * API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: UDF3[Column, Column, Column, Column]): UserDefinedFunction = { - udf(JDFUDF3(f), LongType) - } + def df_udf(f: UDF3[Column, Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to nondeterministic, call the * API `UserDefinedFunction.asNondeterministic()`. 
*/ - def df_udf(f: UDF4[Column, Column, Column, Column, Column]): UserDefinedFunction = { - udf(JDFUDF4(f), LongType) - } + def df_udf(f: UDF4[Column, Column, Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). * By default the returned UDF is deterministic. To change it to nondeterministic, call the * API `UserDefinedFunction.asNondeterministic()`. */ - def df_udf(f: UDF5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction = { - udf(JDFUDF5(f), LongType) - } + def df_udf(f: UDF5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction = + impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). @@ -185,9 +167,7 @@ object functions { * API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: UDF6[Column, Column, Column, Column, Column, Column, - Column]): UserDefinedFunction = { - udf(JDFUDF6(f), LongType) - } + Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). @@ -195,9 +175,7 @@ object functions { * API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: UDF7[Column, Column, Column, Column, Column, Column, - Column, Column]): UserDefinedFunction = { - udf(JDFUDF7(f), LongType) - } + Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). @@ -205,9 +183,7 @@ object functions { * API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: UDF8[Column, Column, Column, Column, Column, Column, - Column, Column, Column]): UserDefinedFunction = { - udf(JDFUDF8(f), LongType) - } + Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). @@ -215,9 +191,7 @@ object functions { * API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: UDF9[Column, Column, Column, Column, Column, Column, - Column, Column, Column, Column]): UserDefinedFunction = { - udf(JDFUDF9(f), LongType) - } + Column, Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) /** * Defines a Java UDF instance of Columns as user-defined function (UDF). @@ -225,8 +199,5 @@ object functions { * API `UserDefinedFunction.asNondeterministic()`. */ def df_udf(f: UDF10[Column, Column, Column, Column, Column, Column, - Column, Column, Column, Column, Column]): UserDefinedFunction = { - udf(JDFUDF10(f), LongType) - } - + Column, Column, Column, Column, Column]): UserDefinedFunction = impl.df_udf(f) } \ No newline at end of file diff --git a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/DFUDFPluginAPI.scala b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/DFUDFPluginAPI.scala new file mode 100644 index 00000000000..1ca16ea1873 --- /dev/null +++ b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/DFUDFPluginAPI.scala @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +trait DFUDFPluginAPI { + def apply(extensions: SparkSessionExtensions): Unit + + def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] +} diff --git a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/Functions.scala b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/Functions.scala new file mode 100644 index 00000000000..b9979d03186 --- /dev/null +++ b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/Functions.scala @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.Column +import org.apache.spark.sql.api.java.{UDF0, UDF1, UDF10, UDF2, UDF3, UDF4, UDF5, UDF6, UDF7, UDF8, UDF9} +import org.apache.spark.sql.expressions.UserDefinedFunction + +trait Functions { + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function0[Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function1[Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function2[Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function3[Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function4[Column, Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
+ */ + def df_udf(f: Function5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function6[Column, Column, Column, Column, Column, Column, + Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function7[Column, Column, Column, Column, Column, Column, + Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function8[Column, Column, Column, Column, Column, Column, + Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function9[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function10[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column, Column]): UserDefinedFunction + + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF0[Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF1[Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF2[Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF3[Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ */ + def df_udf(f: UDF4[Column, Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF6[Column, Column, Column, Column, Column, Column, + Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF7[Column, Column, Column, Column, Column, Column, + Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF8[Column, Column, Column, Column, Column, Column, + Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF9[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column]): UserDefinedFunction + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF10[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column, Column]): UserDefinedFunction +} diff --git a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala index bc35dad5372..23a9ece7468 100644 --- a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala +++ b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala @@ -360,6 +360,14 @@ object ShimLoader { newInstanceOf("com.nvidia.spark.rapids.InternalExclusiveModeGpuDiscoveryPlugin") } + def newFunctionsImpl(): Functions = { + ShimReflectionUtils.newInstanceOf("com.nvidia.spark.FunctionsImpl") + } + + def newDFUDFImpl(): DFUDFPluginAPI = { + ShimReflectionUtils.newInstanceOf("com.nvidia.spark.DFUDFPluginImpl") + } + def loadColumnarRDD(): Class[_] = { ShimReflectionUtils. loadClass("org.apache.spark.sql.rapids.execution.InternalColumnarRddConverter") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/DFUDFPluginImpl.scala b/sql-plugin/src/main/scala/com/nvidia/spark/DFUDFPluginImpl.scala new file mode 100644 index 00000000000..ad1511d9aaf --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/DFUDFPluginImpl.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark + +import com.nvidia.spark.rapids.DFUDFPluginAPI + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +class DFUDFPluginImpl extends DFUDFPluginAPI { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectResolutionRule(logicalPlanRules) + } + + override def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] = { + org.apache.spark.sql.nvidia.LogicalPlanRules() + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/FunctionsImpl.scala b/sql-plugin/src/main/scala/com/nvidia/spark/FunctionsImpl.scala new file mode 100644 index 00000000000..7c27cb79054 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/FunctionsImpl.scala @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark + +import com.nvidia.spark.rapids.Functions + +import org.apache.spark.sql.Column +import org.apache.spark.sql.api.java.{UDF0, UDF1, UDF10, UDF2, UDF3, UDF4, UDF5, UDF6, UDF7, UDF8, UDF9} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.{udf => sp_udf} +import org.apache.spark.sql.nvidia._ +import org.apache.spark.sql.types.LongType + +// scalastyle:off +class FunctionsImpl extends Functions { +// scalastyle:on + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function0[Column]): UserDefinedFunction = + sp_udf(DFUDF0(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function1[Column, Column]): UserDefinedFunction = + sp_udf(DFUDF1(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function2[Column, Column, Column]): UserDefinedFunction = + sp_udf(DFUDF2(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). 
+ * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function3[Column, Column, Column, Column]): UserDefinedFunction = + sp_udf(DFUDF3(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function4[Column, Column, Column, Column, Column]): UserDefinedFunction = + sp_udf(DFUDF4(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function5[Column, Column, Column, Column, Column, + Column]): UserDefinedFunction = sp_udf(DFUDF5(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function6[Column, Column, Column, Column, Column, Column, + Column]): UserDefinedFunction = sp_udf(DFUDF6(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function7[Column, Column, Column, Column, Column, Column, + Column, Column]): UserDefinedFunction = sp_udf(DFUDF7(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function8[Column, Column, Column, Column, Column, Column, + Column, Column, Column]): UserDefinedFunction = sp_udf(DFUDF8(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function9[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column]): UserDefinedFunction = sp_udf(DFUDF9(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: Function10[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column, Column]): UserDefinedFunction = sp_udf(DFUDF10(f), LongType) + + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ */ + override def df_udf(f: UDF0[Column]): UserDefinedFunction = + sp_udf(JDFUDF0(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF1[Column, Column]): UserDefinedFunction = + sp_udf(JDFUDF1(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF2[Column, Column, Column]): UserDefinedFunction = + sp_udf(JDFUDF2(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF3[Column, Column, Column, Column]): UserDefinedFunction = + sp_udf(JDFUDF3(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF4[Column, Column, Column, Column, Column]): UserDefinedFunction = + sp_udf(JDFUDF4(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF5[Column, Column, Column, Column, Column, + Column]): UserDefinedFunction = sp_udf(JDFUDF5(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF6[Column, Column, Column, Column, Column, Column, + Column]): UserDefinedFunction = sp_udf(JDFUDF6(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF7[Column, Column, Column, Column, Column, Column, + Column, Column]): UserDefinedFunction = sp_udf(JDFUDF7(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF8[Column, Column, Column, Column, Column, Column, + Column, Column, Column]): UserDefinedFunction = sp_udf(JDFUDF8(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF9[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column]): UserDefinedFunction = sp_udf(JDFUDF9(f), LongType) + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. 
To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + override def df_udf(f: UDF10[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column, Column]): UserDefinedFunction = sp_udf(JDFUDF10(f), LongType) +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 3e415e23d56..e20b21da520 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -27,6 +27,7 @@ import scala.sys.process._ import scala.util.Try import ai.rapids.cudf.{Cuda, CudaException, CudaFatalException, CudfException, MemoryCleaner, NvtxColor, NvtxRange} +import com.nvidia.spark.DFUDFPlugin import com.nvidia.spark.rapids.RapidsConf.AllowMultipleJars import com.nvidia.spark.rapids.RapidsPluginUtils.buildInfoEvent import com.nvidia.spark.rapids.filecache.{FileCache, FileCacheLocalityManager, FileCacheLocalityMsg} @@ -65,6 +66,7 @@ object RapidsPluginUtils extends Logging { private val SQL_PLUGIN_NAME = classOf[SQLExecPlugin].getName private val UDF_PLUGIN_NAME = "com.nvidia.spark.udf.Plugin" + private val DFUDF_PLUGIN_NAME = classOf[DFUDFPlugin].getName private val SQL_PLUGIN_CONF_KEY = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key private val SERIALIZER_CONF_KEY = "spark.serializer" private val JAVA_SERIALIZER_NAME = classOf[JavaSerializer].getName @@ -226,9 +228,10 @@ object RapidsPluginUtils extends Logging { } def fixupConfigsOnDriver(conf: SparkConf): Unit = { + val plugins = Array(SQL_PLUGIN_NAME, UDF_PLUGIN_NAME, DFUDF_PLUGIN_NAME) // First add in the SQL executor plugin because that is what we need at a minimum if (conf.contains(SQL_PLUGIN_CONF_KEY)) { - for (pluginName <- Array(SQL_PLUGIN_NAME, UDF_PLUGIN_NAME)){ + for (pluginName <- plugins){ val previousValue = conf.get(SQL_PLUGIN_CONF_KEY).split(",").map(_.trim) if (!previousValue.contains(pluginName)) { conf.set(SQL_PLUGIN_CONF_KEY, (previousValue :+ pluginName).mkString(",")) @@ -237,7 +240,7 @@ object RapidsPluginUtils extends Logging { } } } else { - conf.set(SQL_PLUGIN_CONF_KEY, Array(SQL_PLUGIN_NAME,UDF_PLUGIN_NAME).mkString(",")) + conf.set(SQL_PLUGIN_CONF_KEY, plugins.mkString(",")) } val serializer = conf.get(SERIALIZER_CONF_KEY, JAVA_SERIALIZER_NAME) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index a0926fe4d43..ca906baffcc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -822,6 +822,12 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .booleanConf .createWithDefault(false) + val DFUDF_ENABLED = conf("spark.rapids.sql.dfudf.enabled") + .doc("When set to false, the DataFrame UDF plugin is disabled. 
True enables it.") + .internal() + .booleanConf + .createWithDefault(true) + val INCOMPATIBLE_OPS = conf("spark.rapids.sql.incompatibleOps.enabled") .doc("For operations that work, but are not 100% compatible with the Spark equivalent " + "set if they should be enabled by default or disabled by default.") @@ -2598,6 +2604,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isUdfCompilerEnabled: Boolean = get(UDF_COMPILER_ENABLED) + lazy val isDfUdfEnabled: Boolean = get(DFUDF_ENABLED) + lazy val exportColumnarRdd: Boolean = get(EXPORT_COLUMNAR_RDD) lazy val shuffledHashJoinOptimizeShuffle: Boolean = get(SHUFFLED_HASH_JOIN_OPTIMIZE_SHUFFLE) diff --git a/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala similarity index 85% rename from df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala rename to sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala index 24a123016d6..923d92572b5 100644 --- a/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala @@ -16,6 +16,8 @@ package org.apache.spark.sql.nvidia +import com.nvidia.spark.rapids.RapidsConf + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -32,6 +34,11 @@ case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { } } - override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformExpressions(replacePartialFunc) + override def apply(plan: LogicalPlan): LogicalPlan = { + if (RapidsConf.DFUDF_ENABLED.get(plan.conf)) { + plan.transformExpressions(replacePartialFunc) + } else { + plan + } + } } diff --git a/df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala similarity index 100% rename from df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala rename to sql-plugin/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala diff --git a/df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala similarity index 100% rename from df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala rename to sql-plugin/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala diff --git a/df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala similarity index 100% rename from df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala rename to sql-plugin/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala