diff --git a/velox/functions/lib/MapFromEntries.cpp b/velox/functions/lib/MapFromEntries.cpp index 2263043910530..e402d079a1008 100644 --- a/velox/functions/lib/MapFromEntries.cpp +++ b/velox/functions/lib/MapFromEntries.cpp @@ -30,10 +30,10 @@ static const char* kIndeterminateKeyErrorMessage = "map key cannot be indeterminate"; static const char* kErrorMessageEntryNotNull = "map entry cannot be null"; -// allowNullEle If true, will return null if map has -// an entry with null as key or map is null (Spark's behavior) -// instead of throw execeptions(Presto's behavior) -template +/// @tparam throwForNull If true, throws an exception when the input array has +/// a null entry (Presto's behavior). If false, returns null for that row +/// instead (Spark's behavior). +template class MapFromEntriesFunction : public exec::VectorFunction { public: void apply( @@ -98,7 +98,7 @@ class MapFromEntriesFunction : public exec::VectorFunction { exec::LocalDecodedVector decodedRowVector(context); decodedRowVector.get()->decode(*inputValueVector); if (inputValueVector->typeKind() == TypeKind::UNKNOWN) { - if (!allowNullEle) { + if constexpr (throwForNull) { try { VELOX_USER_FAIL(kErrorMessageEntryNotNull); } catch (...) { @@ -144,12 +144,13 @@ class MapFromEntriesFunction : public exec::VectorFunction { for (auto i = 0; i < size; ++i) { // Check nulls in the top level row vector. const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i); - if (isMapEntryNull && allowNullEle) { - // Spark: For nulls in the top level row vector, return null. - bits::setNull(mutableNulls, row); - resetSize(row); - break; - } else if (isMapEntryNull) { + if (isMapEntryNull) { + if constexpr (!throwForNull) { + // Spark: For nulls in the top level row vector, return null. + bits::setNull(mutableNulls, row); + resetSize(row); + break; + } // Presto: Set the sizes to 0 so that the final map vector generated // is valid in case we are inside a try. 
The map vector needs to be // valid because its consumed by checkDuplicateKeys before try @@ -228,34 +229,45 @@ class MapFromEntriesFunction : public exec::VectorFunction { // For Presto, need construct map vector based on input nulls for possible // outer expression like try(). For Spark, use the updated nulls. - auto mapVetorNulls = allowNullEle ? nulls : inputArray->nulls(); - auto mapVector = std::make_shared( - context.pool(), - outputType, - mapVetorNulls, - rows.end(), - inputArray->offsets(), - sizes, - wrappedKeys, - wrappedValues); - + std::shared_ptr mapVector; + if constexpr (throwForNull) { + mapVector = std::make_shared( + context.pool(), + outputType, + inputArray->nulls(), + rows.end(), + inputArray->offsets(), + sizes, + wrappedKeys, + wrappedValues); + } else { + mapVector = std::make_shared( + context.pool(), + outputType, + nulls, + rows.end(), + inputArray->offsets(), + sizes, + wrappedKeys, + wrappedValues); + } checkDuplicateKeys(mapVector, *remianingRows, context); return mapVector; } }; } // namespace -void registerMapFromEntriesFunction(const std::string& name) { +void registerMapFromEntriesThrowForNullFunction(const std::string& name) { exec::registerVectorFunction( name, - MapFromEntriesFunction::signatures(), - std::make_unique>()); + MapFromEntriesFunction::signatures(), + std::make_unique>()); } -void registerMapFromEntriesAllowNullEleFunction(const std::string& name) { +void registerMapFromEntriesFunction(const std::string& name) { exec::registerVectorFunction( name, - MapFromEntriesFunction::signatures(), - std::make_unique>()); + MapFromEntriesFunction::signatures(), + std::make_unique>()); } } // namespace facebook::velox::functions diff --git a/velox/functions/lib/MapFromEntries.h b/velox/functions/lib/MapFromEntries.h index 02860d62e4f0e..eeb3d141075e4 100644 --- a/velox/functions/lib/MapFromEntries.h +++ b/velox/functions/lib/MapFromEntries.h @@ -20,8 +20,8 @@ namespace facebook::velox::functions { -void 
registerMapFromEntriesFunction(const std::string& name); +void registerMapFromEntriesThrowForNullFunction(const std::string& name); -void registerMapFromEntriesAllowNullEleFunction(const std::string& name); +void registerMapFromEntriesFunction(const std::string& name); } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp index 6c811eafb1e02..10f606ad660ac 100644 --- a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp @@ -42,7 +42,7 @@ void registerMapFunctions(const std::string& prefix) { udf_transform_values, prefix + "transform_values"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries"); - registerMapFromEntriesFunction(prefix + "map_from_entries"); + registerMapFromEntriesThrowForNullFunction(prefix + "map_from_entries"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values"); diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 0b5772905c992..0858c2bd15adc 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -73,7 +73,7 @@ static void workAroundRegistrationMacro(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION( udf_map_allow_duplicates, prefix + "map_from_arrays"); - registerMapFromEntriesAllowNullEleFunction(prefix + "map_from_entries"); + registerMapFromEntriesFunction(prefix + "map_from_entries"); VELOX_REGISTER_VECTOR_FUNCTION( udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor); // String functions.