From a2e96f7035f4943a6f7a54c272c99693dfebcd1e Mon Sep 17 00:00:00 2001 From: Zhen Li <10524738+zhli1142015@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:15:40 +0800 Subject: [PATCH] [VL] Handle try_subtract, try_multiply, try_divide (#5985) [VL] Handle try_subtract, try_multiply, try_divide. --- .../velox/VeloxSparkPlanExecApi.scala | 28 +++++---- .../ScalarFunctionsValidateSuite.scala | 24 ++++++++ .../functions/RegistrationAllFunctions.cc | 5 ++ cpp/velox/substrait/SubstraitParser.cc | 1 - docs/velox-backend-support-progress.md | 2 +- .../gluten/backendsapi/SparkPlanExecApi.scala | 12 ++-- .../expression/ExpressionConverter.scala | 58 +++++++++++++++++-- .../utils/velox/VeloxTestSettings.scala | 3 +- .../expressions/GlutenTryEvalSuite.scala | 21 +++++++ .../utils/velox/VeloxTestSettings.scala | 3 +- .../expressions/GlutenTryEvalSuite.scala | 21 +++++++ .../gluten/expression/ExpressionNames.scala | 5 +- .../sql/shims/spark34/Spark34Shims.scala | 6 ++ .../sql/shims/spark35/Spark35Shims.scala | 6 ++ 14 files changed, 171 insertions(+), 24 deletions(-) create mode 100644 gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala create mode 100644 gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 16c11f111abc..f8af80a9b44d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -123,42 +123,50 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { original) } - override def genTryAddTransformer( + override def genTryArithmeticTransformer( substraitExprName: String, left: ExpressionTransformer, right: ExpressionTransformer, - original: TryEval): ExpressionTransformer = { + original: TryEval, + checkArithmeticExprName: String): ExpressionTransformer = { if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original.child)) { - throw new GlutenNotSupportException(s"add with ansi mode is not supported") + throw new GlutenNotSupportException( + s"${original.child.prettyName} with ansi mode is not supported") } original.child.dataType match { case LongType | IntegerType | ShortType | ByteType => - case _ => throw new GlutenNotSupportException(s"try_add is not supported") + case _ => throw new GlutenNotSupportException(s"$substraitExprName is not supported") } // Offload to velox for only IntegralTypes. GenericExpressionTransformer( substraitExprName, - Seq(GenericExpressionTransformer(ExpressionNames.TRY_ADD, Seq(left, right), original)), + Seq(GenericExpressionTransformer(checkArithmeticExprName, Seq(left, right), original)), original) } - override def genAddTransformer( + /** + * Map arithmetic expr to different functions: substraitExprName or try(checkArithmeticExprName) + * based on EvalMode. + */ + override def genArithmeticTransformer( substraitExprName: String, left: ExpressionTransformer, right: ExpressionTransformer, - original: Add): ExpressionTransformer = { + original: Expression, + checkArithmeticExprName: String): ExpressionTransformer = { if (SparkShimLoader.getSparkShims.withTryEvalMode(original)) { original.dataType match { case LongType | IntegerType | ShortType | ByteType => - case _ => throw new GlutenNotSupportException(s"try_add is not supported") + case _ => + throw new GlutenNotSupportException(s"$substraitExprName with try mode is not supported") } // Offload to velox for only IntegralTypes. GenericExpressionTransformer( ExpressionMappings.expressionsMap(classOf[TryEval]), - Seq(GenericExpressionTransformer(ExpressionNames.TRY_ADD, Seq(left, right), original)), + Seq(GenericExpressionTransformer(checkArithmeticExprName, Seq(left, right), original)), original) } else if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original)) { - throw new GlutenNotSupportException(s"add with ansi mode is not supported") + throw new GlutenNotSupportException(s"$substraitExprName with ansi mode is not supported") } else { GenericExpressionTransformer(substraitExprName, Seq(left, right), original) } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 8802c61c5f04..6df3a062331f 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -843,6 +843,30 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { } } + testWithSpecifiedSparkVersion("try_subtract", Some("3.3")) { + runQueryAndCompare( + "select try_subtract(2147483647, cast(l_orderkey as int)), " + + "try_subtract(-2147483648, cast(l_orderkey as int)) from lineitem") { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } + + test("try_divide") { + runQueryAndCompare( + "select try_divide(cast(l_orderkey as int), 0) from lineitem", + noFallBack = false) { + _ => // Spark would always cast inputs to double for this function. + } + } + + testWithSpecifiedSparkVersion("try_multiply", Some("3.3")) { + runQueryAndCompare( + "select try_multiply(2147483647, cast(l_orderkey as int)), " + + "try_multiply(-2147483648, cast(l_orderkey as int)) from lineitem") { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } + test("test array forall") { withTempPath { path => diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index b88d781b69b2..b827690d1cdf 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -21,6 +21,7 @@ #include "operators/functions/RowFunctionWithNull.h" #include "velox/expression/SpecialFormRegistry.h" #include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/CheckedArithmetic.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" @@ -67,6 +68,10 @@ void registerFunctionOverwrite() { kRowConstructorWithAllNull, std::make_unique(kRowConstructorWithAllNull)); velox::functions::sparksql::registerBitwiseFunctions("spark_"); + velox::functions::registerBinaryIntegral({"check_add"}); + velox::functions::registerBinaryIntegral({"check_subtract"}); + velox::functions::registerBinaryIntegral({"check_multiply"}); + velox::functions::registerBinaryIntegral({"check_divide"}); } } // namespace diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index f417618d8117..0880f3e3d915 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -400,7 +400,6 @@ std::unordered_map SubstraitParser::substraitVeloxFunc {"modulus", "remainder"}, {"date_format", "format_datetime"}, {"collect_set", "set_agg"}, - {"try_add", "plus"}, {"forall", "all_match"}, {"exists", "any_match"}, {"negative", "unaryminus"}, diff --git a/docs/velox-backend-support-progress.md b/docs/velox-backend-support-progress.md index 5d083c4e59ba..f39bd7016707 100644 --- a/docs/velox-backend-support-progress.md +++ b/docs/velox-backend-support-progress.md @@ -100,7 +100,7 @@ Gluten supports 199 functions. (Drag to right to see all data types) | & | bitwise_and | bitwise_and | S | | | | | | | | | | | | | | | | | | | | | * | multiply | multiply | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | | | + | plus | add | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | | -| - | minus | substract | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | | +| - | minus | subtract | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | | | / | divide | divide | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | | | < | lt | lessthan | S | | S | S | S | S | S | S | S | | | S | | | | | | | | | | <= | lte | lessthanorequa | S | | S | S | S | S | S | S | S | | | S | | | | | | | | | diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 78cf02f0ac24..8a086f896ba4 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -206,12 +206,13 @@ trait SparkPlanExecApi { GenericExpressionTransformer(substraitExprName, Seq(), original) } - def genTryAddTransformer( + def genTryArithmeticTransformer( substraitExprName: String, left: ExpressionTransformer, right: ExpressionTransformer, - original: TryEval): ExpressionTransformer = { - throw new GlutenNotSupportException("try_add is not supported") + original: TryEval, + checkArithmeticExprName: String): ExpressionTransformer = { + throw new GlutenNotSupportException(s"$checkArithmeticExprName is not supported") } def genTryEvalTransformer( @@ -221,11 +222,12 @@ trait SparkPlanExecApi { throw new GlutenNotSupportException("try_eval is not supported") } - def genAddTransformer( + def genArithmeticTransformer( substraitExprName: String, left: ExpressionTransformer, right: ExpressionTransformer, - original: Add): ExpressionTransformer = { + original: Expression, + checkArithmeticExprName: String): ExpressionTransformer = { GenericExpressionTransformer(substraitExprName, Seq(left, right), original) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index b66ec89eaf2b..9ebe44f6ca54 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -563,18 +563,68 @@ object ExpressionConverter extends SQLConfHelper with Logging { arrayTransform ) case tryEval @ TryEval(a: Add) => - BackendsApiManager.getSparkPlanExecApiInstance.genTryAddTransformer( + BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), - tryEval + tryEval, + ExpressionNames.CHECK_ADD + ) + case tryEval @ TryEval(a: Subtract) => + BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + tryEval, + ExpressionNames.CHECK_SUBTRACT + ) + case tryEval @ TryEval(a: Divide) => + BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + tryEval, + ExpressionNames.CHECK_DIVIDE + ) + case tryEval @ TryEval(a: Multiply) => + BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + tryEval, + ExpressionNames.CHECK_MULTIPLY ) case a: Add => - BackendsApiManager.getSparkPlanExecApiInstance.genAddTransformer( + BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), - a + a, + ExpressionNames.CHECK_ADD + ) + case a: Subtract => + BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + a, + ExpressionNames.CHECK_SUBTRACT + ) + case a: Multiply => + BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + a, + ExpressionNames.CHECK_MULTIPLY + ) + case a: Divide => + BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + a, + ExpressionNames.CHECK_DIVIDE ) case tryEval: TryEval => // This is a placeholder to handle try_eval(other expressions). diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index d8e3a5ecc051..bd437bbe8efb 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -20,7 +20,7 @@ import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.GlutenSortShuffleSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite} +import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite, GlutenTryEvalSuite} import org.apache.spark.sql.connector.{GlutenDataSourceV2DataFrameSessionCatalogSuite, GlutenDataSourceV2DataFrameSuite, GlutenDataSourceV2FunctionSuite, GlutenDataSourceV2SQLSessionCatalogSuite, GlutenDataSourceV2SQLSuiteV1Filter, GlutenDataSourceV2SQLSuiteV2Filter, GlutenDataSourceV2Suite, GlutenDeleteFromTableSuite, GlutenDeltaBasedDeleteFromTableSuite, GlutenFileDataSourceV2FallBackSuite, GlutenGroupBasedDeleteFromTableSuite, GlutenKeyGroupedPartitioningSuite, GlutenLocalScanSuite, GlutenMetadataColumnSuite, GlutenSupportsCatalogOptionsSuite, GlutenTableCapabilityCheckSuite, GlutenWriteDistributionAndOrderingSuite} import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution.{FallbackStrategiesSuite, GlutenBroadcastExchangeSuite, GlutenCoalesceShufflePartitionsSuite, GlutenExchangeSuite, GlutenLocalBroadcastExchangeSuite, GlutenReplaceHashWithSortAggSuite, GlutenReuseExchangeAndSubquerySuite, GlutenSameResultSuite, GlutenSortSuite, GlutenSQLAggregateFunctionSuite, GlutenSQLWindowFunctionSuite, GlutenTakeOrderedAndProjectSuite} @@ -141,6 +141,7 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenSortShuffleSuite] enableSuite[GlutenSortOrderExpressionsSuite] enableSuite[GlutenStringExpressionsSuite] + enableSuite[GlutenTryEvalSuite] enableSuite[VeloxAdaptiveQueryExecSuite] .includeAllGlutenTests() .includeByPrefix( diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala new file mode 100644 index 000000000000..6af97677e5d8 --- /dev/null +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.GlutenTestsTrait + +class GlutenTryEvalSuite extends TryEvalSuite with GlutenTestsTrait {} diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 10f7be4feaeb..af8d0deadfc8 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -20,7 +20,7 @@ import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.GlutenSortShuffleSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite} +import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite, GlutenTryEvalSuite} import org.apache.spark.sql.connector._ import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution._ @@ -144,6 +144,7 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenSortShuffleSuite] enableSuite[GlutenSortOrderExpressionsSuite] enableSuite[GlutenStringExpressionsSuite] + enableSuite[GlutenTryEvalSuite] enableSuite[VeloxAdaptiveQueryExecSuite] .includeAllGlutenTests() .includeByPrefix( diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala new file mode 100644 index 000000000000..6af97677e5d8 --- /dev/null +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenTryEvalSuite.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.GlutenTestsTrait + +class GlutenTryEvalSuite extends TryEvalSuite with GlutenTestsTrait {} diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index be7e32fc97d6..dc98f31a395c 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -83,7 +83,10 @@ object ExpressionNames { final val IS_NAN = "isnan" final val NANVL = "nanvl" final val TRY_EVAL = "try" - final val TRY_ADD = "try_add" + final val CHECK_ADD = "check_add" + final val CHECK_SUBTRACT = "check_subtract" + final val CHECK_DIVIDE = "check_divide" + final val CHECK_MULTIPLY = "check_multiply" // SparkSQL String functions final val ASCII = "ascii" diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index 4ab307e8568f..f2c2482949b7 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -407,6 +407,9 @@ class Spark34Shims extends SparkShims { override def withTryEvalMode(expr: Expression): Boolean = { expr match { case a: Add => a.evalMode == EvalMode.TRY + case s: Subtract => s.evalMode == EvalMode.TRY + case d: Divide => d.evalMode == EvalMode.TRY + case m: Multiply => m.evalMode == EvalMode.TRY case _ => false } } @@ -414,6 +417,9 @@ class Spark34Shims extends SparkShims { override def withAnsiEvalMode(expr: Expression): Boolean = { expr match { case a: Add => a.evalMode == EvalMode.ANSI + case s: Subtract => s.evalMode == EvalMode.ANSI + case d: Divide => d.evalMode == EvalMode.ANSI + case m: Multiply => m.evalMode == EvalMode.ANSI case _ => false } } diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index ef1cea865d49..e0835c3069d2 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -436,6 +436,9 @@ class Spark35Shims extends SparkShims { override def withTryEvalMode(expr: Expression): Boolean = { expr match { case a: Add => a.evalMode == EvalMode.TRY + case s: Subtract => s.evalMode == EvalMode.TRY + case d: Divide => d.evalMode == EvalMode.TRY + case m: Multiply => m.evalMode == EvalMode.TRY case _ => false } } @@ -443,6 +446,9 @@ class Spark35Shims extends SparkShims { override def withAnsiEvalMode(expr: Expression): Boolean = { expr match { case a: Add => a.evalMode == EvalMode.ANSI + case s: Subtract => s.evalMode == EvalMode.ANSI + case d: Divide => d.evalMode == EvalMode.ANSI + case m: Multiply => m.evalMode == EvalMode.ANSI case _ => false } }