From 1a394a7600140cc699ae88a6edb6f3263f7ef06c Mon Sep 17 00:00:00 2001 From: yan ma Date: Thu, 16 May 2024 21:47:16 +0800 Subject: [PATCH] address comments --- velox/docs/functions/spark/map.rst | 17 +++++----- velox/functions/lib/MapFromEntries.cpp | 34 +++++++++---------- velox/functions/lib/MapFromEntries.h | 6 ++-- .../registration/MapFunctionsRegistration.cpp | 2 +- velox/functions/sparksql/Register.cpp | 2 +- 5 files changed, 29 insertions(+), 32 deletions(-) diff --git a/velox/docs/functions/spark/map.rst b/velox/docs/functions/spark/map.rst index 175c19298766c..f02d6da44e046 100644 --- a/velox/docs/functions/spark/map.rst +++ b/velox/docs/functions/spark/map.rst @@ -33,6 +33,14 @@ Map Functions SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); -- {1.0 -> 2, 3.0 -> 4} +.. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V) + + Returns a map created from the given array of entries. An exception is thrown if there are duplicate keys, or if a key is null or contains null. + If a null entry exists in the array, returns null for the whole array. :: + + SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'} + SELECT map_from_entries(array(struct(1, 'a'), null)); -- {null} + .. spark::function:: map_keys(x(K,V)) -> array(K) Returns all the keys in the map ``x``. @@ -41,15 +49,6 @@ Map Functions Returns all the values in the map ``x``. -.. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V) - - Returns a map created from the given array of entries. Exceptions will be thrown if key is null or contains null. - If null entry exists in the array, return null for this whole array.:: - - SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'} - SELECT map_from_entries(array(struct(1, 'a'), null)); -- {null} - - .. 
spark:function:: size(map(K,V)) -> bigint :noindex: diff --git a/velox/functions/lib/MapFromEntries.cpp b/velox/functions/lib/MapFromEntries.cpp index 445aaa10f8c17..cf712cf4d5264 100644 --- a/velox/functions/lib/MapFromEntries.cpp +++ b/velox/functions/lib/MapFromEntries.cpp @@ -30,12 +30,12 @@ static const char* kIndeterminateKeyErrorMessage = "map key cannot be indeterminate"; static const char* kErrorMessageEntryNotNull = "map entry cannot be null"; -/// @tparam throwForNull If true, will return null if input array is null or has -/// null entries (Spark's behavior), instead of throwing exceptions (Presto's -/// behavior). -template class MapFromEntriesFunction : public exec::VectorFunction { public: + // If throwOnNull is true, throws if the input array has null entries + // (Presto's behavior). Otherwise, returns null if the input array is null or + // has null entries (Spark's behavior). + MapFromEntriesFunction(const bool throwOnNull) : throwOnNull_(throwOnNull) {} void apply( const SelectivityVector& rows, std::vector& args, @@ -98,7 +98,7 @@ class MapFromEntriesFunction : public exec::VectorFunction { exec::LocalDecodedVector decodedRowVector(context); decodedRowVector.get()->decode(*inputValueVector); if (inputValueVector->typeKind() == TypeKind::UNKNOWN) { - if constexpr (throwForNull) { + if (throwOnNull_) { try { VELOX_USER_FAIL(kErrorMessageEntryNotNull); } catch (...) { @@ -145,7 +145,7 @@ class MapFromEntriesFunction : public exec::VectorFunction { // Check nulls in the top level row vector. const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i); if (isMapEntryNull) { - if constexpr (!throwForNull) { + if (!throwOnNull_) { bits::setNull(mutableNulls, row); resetSize(row); break; @@ -227,8 +227,9 @@ class MapFromEntriesFunction : public exec::VectorFunction { } // For Presto, need construct map vector based on input nulls for possible - // outer expression like try(). For Spark, use the updated nulls unless it's empty. 
- if constexpr (throwForNull) { + // outer expression like try(). For Spark, use the updated nulls unless it's + // empty. + if (throwOnNull_) { nulls = inputArray->nulls(); } else { if (decodedRowVector->size() == 0) { @@ -248,20 +249,17 @@ class MapFromEntriesFunction : public exec::VectorFunction { checkDuplicateKeys(mapVector, *remianingRows, context); return mapVector; } + + bool throwOnNull_; }; } // namespace -void registerMapFromEntriesThrowForNullFunction(const std::string& name) { - exec::registerVectorFunction( - name, - MapFromEntriesFunction::signatures(), - std::make_unique>()); -} - -void registerMapFromEntriesFunction(const std::string& name) { +void registerMapFromEntriesFunction( + const std::string& name, + const bool throwOnNull) { exec::registerVectorFunction( name, - MapFromEntriesFunction::signatures(), - std::make_unique>()); + MapFromEntriesFunction::signatures(), + std::make_unique(throwOnNull)); } } // namespace facebook::velox::functions diff --git a/velox/functions/lib/MapFromEntries.h b/velox/functions/lib/MapFromEntries.h index eeb3d141075e4..f247b4cf450f3 100644 --- a/velox/functions/lib/MapFromEntries.h +++ b/velox/functions/lib/MapFromEntries.h @@ -20,8 +20,8 @@ namespace facebook::velox::functions { -void registerMapFromEntriesThrowForNullFunction(const std::string& name); - -void registerMapFromEntriesFunction(const std::string& name); +void registerMapFromEntriesFunction( + const std::string& name, + const bool throwForNull); } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp index 73c978a0afa6c..4d974393fcfbb 100644 --- a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp @@ -69,7 +69,7 @@ void registerMapFunctions(const std::string& prefix) { udf_transform_values, prefix + "transform_values"); 
VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries"); - registerMapFromEntriesThrowForNullFunction(prefix + "map_from_entries"); + registerMapFromEntriesFunction(prefix + "map_from_entries", true); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values"); diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 96bb15bff1a06..fe320e4687d7a 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -112,7 +112,7 @@ static void workAroundRegistrationMacro(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION( udf_map_allow_duplicates, prefix + "map_from_arrays"); - registerMapFromEntriesFunction(prefix + "map_from_entries"); + registerMapFromEntriesFunction(prefix + "map_from_entries", false); VELOX_REGISTER_VECTOR_FUNCTION( udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor); // String functions.