From a1e0367e9d129cc5f9ca30f596819d1546a9a1fd Mon Sep 17 00:00:00 2001 From: WangGuangxin Date: Tue, 19 Mar 2024 10:51:02 +0800 Subject: [PATCH] [GLUTEN-5003][VL] Fix Null literal fallback (#5004) --- .../glutenproject/execution/VeloxLiteralSuite.scala | 6 +++--- cpp/velox/substrait/SubstraitParser.cc | 5 +++++ cpp/velox/substrait/SubstraitToVeloxExpr.cc | 11 +++++------ cpp/velox/substrait/SubstraitToVeloxExpr.h | 1 - 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala index d625b80783f0..557681558f56 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala @@ -115,6 +115,9 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite { validateOffloadResult("SELECT struct('Spark', cast(null as int))") validateOffloadResult("SELECT struct(cast(null as decimal))") validateOffloadResult("SELECT map('b', 'a', 'e', null)") + validateOffloadResult("SELECT array(null)") + validateOffloadResult("SELECT array(cast(null as int))") + validateOffloadResult("SELECT map(1, null)") } test("Scalar Type Literal") { @@ -132,9 +135,6 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite { } test("Literal Fallback") { - validateFallbackResult("SELECT array(null)") - validateFallbackResult("SELECT array(cast(null as int))") - validateFallbackResult("SELECT map(1, null)") validateFallbackResult("SELECT struct(cast(null as struct))") } } diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 35f130076aff..ce6a532ef7e7 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -294,6 +294,11 @@ T SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& /* li VELOX_NYI(); } +template <> +std::shared_ptr gluten::SubstraitParser::getLiteralValue(const substrait::Expression_Literal& literal) { + return nullptr; +} + template <> facebook::velox::UnknownValue gluten::SubstraitParser::getLiteralValue(const substrait::Expression_Literal& literal) { return UnknownValue(); diff --git a/cpp/velox/substrait/SubstraitToVeloxExpr.cc b/cpp/velox/substrait/SubstraitToVeloxExpr.cc index 4071c1b0111b..fd68be16e2bc 100644 --- a/cpp/velox/substrait/SubstraitToVeloxExpr.cc +++ b/cpp/velox/substrait/SubstraitToVeloxExpr.cc @@ -395,7 +395,7 @@ ArrayVectorPtr SubstraitVeloxExprConverter::literalsToArrayVector(const ::substr VELOX_CHECK_GT(childSize, 0, "there should be at least 1 value in list literal."); auto childLiteral = literal.list().values(0); auto elementAtFunc = [&](vector_size_t idx) { return literal.list().values(idx); }; - auto childVector = literalsToVector(childLiteral, childSize, literal, elementAtFunc); + auto childVector = literalsToVector(childLiteral, childSize, elementAtFunc); return makeArrayVector(childVector); } @@ -406,22 +406,21 @@ MapVectorPtr SubstraitVeloxExprConverter::literalsToMapVector(const ::substrait: auto& valueLiteral = literal.map().key_values(0).value(); auto keyAtFunc = [&](vector_size_t idx) { return literal.map().key_values(idx).key(); }; auto valueAtFunc = [&](vector_size_t idx) { return literal.map().key_values(idx).value(); }; - auto keyVector = literalsToVector(keyLiteral, childSize, literal, keyAtFunc); - auto valueVector = literalsToVector(valueLiteral, childSize, literal, valueAtFunc); + auto keyVector = literalsToVector(keyLiteral, childSize, keyAtFunc); + auto valueVector = literalsToVector(valueLiteral, childSize, valueAtFunc); return makeMapVector(keyVector, valueVector); } VectorPtr SubstraitVeloxExprConverter::literalsToVector( const ::substrait::Expression::Literal& childLiteral, vector_size_t childSize, - const ::substrait::Expression::Literal& literal, std::function<::substrait::Expression::Literal(vector_size_t /* idx */)> elementAtFunc) { auto childTypeCase = childLiteral.literal_type_case(); switch (childTypeCase) { case ::substrait::Expression_Literal::LiteralTypeCase::kNull: { - auto veloxType = SubstraitParser::parseType(literal.null()); + auto veloxType = SubstraitParser::parseType(childLiteral.null()); auto kind = veloxType->kind(); - return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(constructFlatVector, kind, elementAtFunc, childSize, veloxType, pool_); + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL(constructFlatVector, kind, elementAtFunc, childSize, veloxType, pool_); } case ::substrait::Expression_Literal::LiteralTypeCase::kIntervalDayToSecond: return constructFlatVector(elementAtFunc, childSize, INTERVAL_DAY_TIME(), pool_); diff --git a/cpp/velox/substrait/SubstraitToVeloxExpr.h b/cpp/velox/substrait/SubstraitToVeloxExpr.h index 5699093b80c4..619d8c31cde0 100644 --- a/cpp/velox/substrait/SubstraitToVeloxExpr.h +++ b/cpp/velox/substrait/SubstraitToVeloxExpr.h @@ -93,7 +93,6 @@ class SubstraitVeloxExprConverter { VectorPtr literalsToVector( const ::substrait::Expression::Literal& childLiteral, vector_size_t childSize, - const ::substrait::Expression::Literal& literal, std::function<::substrait::Expression::Literal(vector_size_t /* idx */)> elementAtFunc); RowVectorPtr literalsToRowVector(const ::substrait::Expression::Literal& structLiteral);