[VL] Handle try_subtract, try_multiply, try_divide (apache#5985)
[VL] Handle try_subtract, try_multiply, try_divide.
zhli1142015 authored Jun 6, 2024
1 parent 95dcdbd commit a2e96f7
Showing 14 changed files with 171 additions and 24 deletions.
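
For context, a minimal, hedged sketch of the Spark 3.3+ semantics these functions must preserve when offloaded (not part of this commit; the object and app names are illustrative only): try_subtract and try_multiply return NULL instead of raising an error on integer overflow, and try_divide returns NULL on division by zero.

import org.apache.spark.sql.SparkSession

// Illustrative sketch only: demonstrates the try_* semantics assumed by this change.
object TryArithmeticSemanticsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("try-arith-demo").getOrCreate()
    // Integer underflow/overflow yields NULL rather than an error.
    spark.sql("SELECT try_subtract(CAST(-2147483648 AS INT), CAST(1 AS INT))").show()
    spark.sql("SELECT try_multiply(CAST(2147483647 AS INT), CAST(2 AS INT))").show()
    // Division by zero yields NULL.
    spark.sql("SELECT try_divide(42, 0)").show()
    spark.stop()
  }
}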
@@ -123,42 +123,50 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
original)
}

override def genTryAddTransformer(
override def genTryArithmeticTransformer(
substraitExprName: String,
left: ExpressionTransformer,
right: ExpressionTransformer,
original: TryEval): ExpressionTransformer = {
original: TryEval,
checkArithmeticExprName: String): ExpressionTransformer = {
if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original.child)) {
throw new GlutenNotSupportException(s"add with ansi mode is not supported")
throw new GlutenNotSupportException(
s"${original.child.prettyName} with ansi mode is not supported")
}
original.child.dataType match {
case LongType | IntegerType | ShortType | ByteType =>
case _ => throw new GlutenNotSupportException(s"try_add is not supported")
case _ => throw new GlutenNotSupportException(s"$substraitExprName is not supported")
}
// Offload to Velox only for integral types.
GenericExpressionTransformer(
substraitExprName,
Seq(GenericExpressionTransformer(ExpressionNames.TRY_ADD, Seq(left, right), original)),
Seq(GenericExpressionTransformer(checkArithmeticExprName, Seq(left, right), original)),
original)
}

override def genAddTransformer(
/**
* Map an arithmetic expression to different functions: substraitExprName or
* try(checkArithmeticExprName), depending on its EvalMode.
*/
override def genArithmeticTransformer(
substraitExprName: String,
left: ExpressionTransformer,
right: ExpressionTransformer,
original: Add): ExpressionTransformer = {
original: Expression,
checkArithmeticExprName: String): ExpressionTransformer = {
if (SparkShimLoader.getSparkShims.withTryEvalMode(original)) {
original.dataType match {
case LongType | IntegerType | ShortType | ByteType =>
case _ => throw new GlutenNotSupportException(s"try_add is not supported")
case _ =>
throw new GlutenNotSupportException(s"$substraitExprName with try mode is not supported")
}
// Offload to Velox only for integral types.
GenericExpressionTransformer(
ExpressionMappings.expressionsMap(classOf[TryEval]),
Seq(GenericExpressionTransformer(ExpressionNames.TRY_ADD, Seq(left, right), original)),
Seq(GenericExpressionTransformer(checkArithmeticExprName, Seq(left, right), original)),
original)
} else if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original)) {
throw new GlutenNotSupportException(s"add with ansi mode is not supported")
throw new GlutenNotSupportException(s"$substraitExprName with ansi mode is not supported")
} else {
GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
}
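
The dispatch above can be summarized as follows. This is a hedged, self-contained sketch (the EvalMode enum and helper are stand-ins for illustration, not Gluten APIs) of how the eval mode selects the function shape that ends up in the Substrait plan.

// Illustrative sketch only: mirrors the branching in genArithmeticTransformer
// for an integral-typed arithmetic expression.
object EvalModeDispatchSketch {
  sealed trait EvalMode
  case object Legacy extends EvalMode
  case object Ansi extends EvalMode
  case object Try extends EvalMode

  // Returns the function shape the transformer would emit.
  def target(substraitExprName: String, checkExprName: String, mode: EvalMode): String =
    mode match {
      case Try    => s"try($checkExprName(l, r))" // e.g. try(check_subtract(l, r)); overflow becomes NULL
      case Ansi   => throw new UnsupportedOperationException(
        s"$substraitExprName with ansi mode is not supported")
      case Legacy => s"$substraitExprName(l, r)"  // plain Velox arithmetic
    }

  def main(args: Array[String]): Unit = {
    println(target("subtract", "check_subtract", Try))    // try(check_subtract(l, r))
    println(target("multiply", "check_multiply", Legacy)) // multiply(l, r)
  }
}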
@@ -843,6 +843,30 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
}
}

testWithSpecifiedSparkVersion("try_subtract", Some("3.3")) {
runQueryAndCompare(
"select try_subtract(2147483647, cast(l_orderkey as int)), " +
"try_subtract(-2147483648, cast(l_orderkey as int)) from lineitem") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}

test("try_divide") {
runQueryAndCompare(
"select try_divide(cast(l_orderkey as int), 0) from lineitem",
noFallBack = false) {
_ => // Spark always casts inputs to double for this function.
}
}
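
A note on the try_divide test above and why it allows fallback (noFallBack = false): Spark rewrites try_divide over integral inputs to a Divide whose operands are cast to double, and the checked arithmetic in this commit is offloaded only for integral types. A hedged sketch (demo-only names, assuming a local Spark 3.3+ session) that makes the double result type visible:

import org.apache.spark.sql.SparkSession

// Illustrative sketch only: try_divide yields a double even for integer inputs,
// so it misses the integral-only offload path.
object TryDivideTypeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("try-divide-type").getOrCreate()
    val df = spark.sql("SELECT try_divide(CAST(1 AS INT), 0) AS v")
    println(df.schema("v").dataType) // DoubleType
    spark.stop()
  }
}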

testWithSpecifiedSparkVersion("try_multiply", Some("3.3")) {
runQueryAndCompare(
"select try_multiply(2147483647, cast(l_orderkey as int)), " +
"try_multiply(-2147483648, cast(l_orderkey as int)) from lineitem") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}

test("test array forall") {
withTempPath {
path =>
5 changes: 5 additions & 0 deletions cpp/velox/operators/functions/RegistrationAllFunctions.cc
@@ -21,6 +21,7 @@
#include "operators/functions/RowFunctionWithNull.h"
#include "velox/expression/SpecialFormRegistry.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
@@ -67,6 +68,10 @@ void registerFunctionOverwrite() {
kRowConstructorWithAllNull,
std::make_unique<RowConstructorWithNullCallToSpecialForm>(kRowConstructorWithAllNull));
velox::functions::sparksql::registerBitwiseFunctions("spark_");
velox::functions::registerBinaryIntegral<velox::functions::CheckedPlusFunction>({"check_add"});
velox::functions::registerBinaryIntegral<velox::functions::CheckedMinusFunction>({"check_subtract"});
velox::functions::registerBinaryIntegral<velox::functions::CheckedMultiplyFunction>({"check_multiply"});
velox::functions::registerBinaryIntegral<velox::functions::CheckedDivideFunction>({"check_divide"});
}
} // namespace

1 change: 0 additions & 1 deletion cpp/velox/substrait/SubstraitParser.cc
@@ -400,7 +400,6 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"modulus", "remainder"},
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"try_add", "plus"},
{"forall", "all_match"},
{"exists", "any_match"},
{"negative", "unaryminus"},
2 changes: 1 addition & 1 deletion docs/velox-backend-support-progress.md
@@ -100,7 +100,7 @@ Gluten supports 199 functions. (Drag to right to see all data types)
| & | bitwise_and | bitwise_and | S | | | | | | | | | | | | | | | | | | | |
| * | multiply | multiply | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | |
| + | plus | add | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | |
| - | minus | substract | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | |
| - | minus | subtract | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | |
| / | divide | divide | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | |
| < | lt | lessthan | S | | S | S | S | S | S | S | S | | | S | | | | | | | | |
| <= | lte | lessthanorequal | S | | S | S | S | S | S | S | S | | | S | | | | | | | | |
@@ -206,12 +206,13 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(), original)
}

def genTryAddTransformer(
def genTryArithmeticTransformer(
substraitExprName: String,
left: ExpressionTransformer,
right: ExpressionTransformer,
original: TryEval): ExpressionTransformer = {
throw new GlutenNotSupportException("try_add is not supported")
original: TryEval,
checkArithmeticExprName: String): ExpressionTransformer = {
throw new GlutenNotSupportException(s"$checkArithmeticExprName is not supported")
}

def genTryEvalTransformer(
@@ -221,11 +222,12 @@
throw new GlutenNotSupportException("try_eval is not supported")
}

def genAddTransformer(
def genArithmeticTransformer(
substraitExprName: String,
left: ExpressionTransformer,
right: ExpressionTransformer,
original: Add): ExpressionTransformer = {
original: Expression,
checkArithmeticExprName: String): ExpressionTransformer = {
GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
}

@@ -563,18 +563,68 @@ object ExpressionConverter extends SQLConfHelper with Logging {
arrayTransform
)
case tryEval @ TryEval(a: Add) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryAddTransformer(
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
tryEval
tryEval,
ExpressionNames.CHECK_ADD
)
case tryEval @ TryEval(a: Subtract) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
tryEval,
ExpressionNames.CHECK_SUBTRACT
)
case tryEval @ TryEval(a: Divide) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
tryEval,
ExpressionNames.CHECK_DIVIDE
)
case tryEval @ TryEval(a: Multiply) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
tryEval,
ExpressionNames.CHECK_MULTIPLY
)
case a: Add =>
BackendsApiManager.getSparkPlanExecApiInstance.genAddTransformer(
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
a
a,
ExpressionNames.CHECK_ADD
)
case a: Subtract =>
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
a,
ExpressionNames.CHECK_SUBTRACT
)
case a: Multiply =>
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
a,
ExpressionNames.CHECK_MULTIPLY
)
case a: Divide =>
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap),
a,
ExpressionNames.CHECK_DIVIDE
)
case tryEval: TryEval =>
// This is a placeholder to handle try_eval(other expressions).
@@ -20,7 +20,7 @@ import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings}

import org.apache.spark.GlutenSortShuffleSuite
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite}
import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite, GlutenTryEvalSuite}
import org.apache.spark.sql.connector.{GlutenDataSourceV2DataFrameSessionCatalogSuite, GlutenDataSourceV2DataFrameSuite, GlutenDataSourceV2FunctionSuite, GlutenDataSourceV2SQLSessionCatalogSuite, GlutenDataSourceV2SQLSuiteV1Filter, GlutenDataSourceV2SQLSuiteV2Filter, GlutenDataSourceV2Suite, GlutenDeleteFromTableSuite, GlutenDeltaBasedDeleteFromTableSuite, GlutenFileDataSourceV2FallBackSuite, GlutenGroupBasedDeleteFromTableSuite, GlutenKeyGroupedPartitioningSuite, GlutenLocalScanSuite, GlutenMetadataColumnSuite, GlutenSupportsCatalogOptionsSuite, GlutenTableCapabilityCheckSuite, GlutenWriteDistributionAndOrderingSuite}
import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite}
import org.apache.spark.sql.execution.{FallbackStrategiesSuite, GlutenBroadcastExchangeSuite, GlutenCoalesceShufflePartitionsSuite, GlutenExchangeSuite, GlutenLocalBroadcastExchangeSuite, GlutenReplaceHashWithSortAggSuite, GlutenReuseExchangeAndSubquerySuite, GlutenSameResultSuite, GlutenSortSuite, GlutenSQLAggregateFunctionSuite, GlutenSQLWindowFunctionSuite, GlutenTakeOrderedAndProjectSuite}
@@ -141,6 +141,7 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenSortShuffleSuite]
enableSuite[GlutenSortOrderExpressionsSuite]
enableSuite[GlutenStringExpressionsSuite]
enableSuite[GlutenTryEvalSuite]
enableSuite[VeloxAdaptiveQueryExecSuite]
.includeAllGlutenTests()
.includeByPrefix(
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.GlutenTestsTrait

class GlutenTryEvalSuite extends TryEvalSuite with GlutenTestsTrait {}
@@ -20,7 +20,7 @@ import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings}

import org.apache.spark.GlutenSortShuffleSuite
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite}
import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite, GlutenTryEvalSuite}
import org.apache.spark.sql.connector._
import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite}
import org.apache.spark.sql.execution._
@@ -144,6 +144,7 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenSortShuffleSuite]
enableSuite[GlutenSortOrderExpressionsSuite]
enableSuite[GlutenStringExpressionsSuite]
enableSuite[GlutenTryEvalSuite]
enableSuite[VeloxAdaptiveQueryExecSuite]
.includeAllGlutenTests()
.includeByPrefix(
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.GlutenTestsTrait

class GlutenTryEvalSuite extends TryEvalSuite with GlutenTestsTrait {}
@@ -83,7 +83,10 @@ object ExpressionNames {
final val IS_NAN = "isnan"
final val NANVL = "nanvl"
final val TRY_EVAL = "try"
final val TRY_ADD = "try_add"
final val CHECK_ADD = "check_add"
final val CHECK_SUBTRACT = "check_subtract"
final val CHECK_DIVIDE = "check_divide"
final val CHECK_MULTIPLY = "check_multiply"

// SparkSQL String functions
final val ASCII = "ascii"