diff --git a/velox/docs/functions/spark/map.rst b/velox/docs/functions/spark/map.rst index 2f74b87c886a2..d114fc223dc83 100644 --- a/velox/docs/functions/spark/map.rst +++ b/velox/docs/functions/spark/map.rst @@ -35,11 +35,13 @@ Map Functions .. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V) - Returns a map created from the given array of entries. Exceptions will be thrown for duplicated keys or key is null or contains null. - If null entry exists in the array, return null for this whole array. :: + Returns a map created from the given array of entries. Exception is thrown if the entries of structs contain duplicate key, + or one entry has a null key. Returns null if one entry is null. :: SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'} SELECT map_from_entries(array(struct(1, 'a'), null)); -- {null} + SELECT map_from_entries(array(struct(null, 'a'))); -- "map key cannot be null" + SELECT map_from_entries(array(struct(1, 'a'), struct(1, 'b'))); -- "Duplicate map keys (1) are not allowed" .. spark:function:: map_keys(x(K,V)) -> array(K) diff --git a/velox/functions/lib/MapFromEntries.cpp b/velox/functions/lib/MapFromEntries.cpp index cf712cf4d5264..ec373d2610b0a 100644 --- a/velox/functions/lib/MapFromEntries.cpp +++ b/velox/functions/lib/MapFromEntries.cpp @@ -32,10 +32,9 @@ static const char* kErrorMessageEntryNotNull = "map entry cannot be null"; class MapFromEntriesFunction : public exec::VectorFunction { public: - // If throwOnNull is true, will return null if input array is null or has - // null entries (Spark's behavior), instead of throwing exceptions (Presto's - // behavior). - MapFromEntriesFunction(const bool throwOnNull) : throwOnNull_(throwOnNull) {} + // @param throwOnNull If true, throws exception when input array is null or + // contains null entry. Otherwise, returns null. + MapFromEntriesFunction(bool throwOnNull) : throwOnNull_(throwOnNull) {} void apply( const SelectivityVector& rows, std::vector& args, @@ -98,6 +97,8 @@ class MapFromEntriesFunction : public exec::VectorFunction { exec::LocalDecodedVector decodedRowVector(context); decodedRowVector.get()->decode(*inputValueVector); if (inputValueVector->typeKind() == TypeKind::UNKNOWN) { + // For presto, if the input array(unknown) then all rows should have + // errors. if (throwOnNull_) { try { VELOX_USER_FAIL(kErrorMessageEntryNotNull); @@ -145,16 +146,13 @@ class MapFromEntriesFunction : public exec::VectorFunction { // Check nulls in the top level row vector. const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i); if (isMapEntryNull) { + // The map vector needs to be valid because its consumed by + // checkDuplicateKeys before try sets invalid rows to null. + resetSize(row); if (!throwOnNull_) { bits::setNull(mutableNulls, row); - resetSize(row); break; } - // Presto: Set the sizes to 0 so that the final map vector generated - // is valid in case we are inside a try. The map vector needs to be - // valid because its consumed by checkDuplicateKeys before try - // sets invalid rows to null. - resetSize(row); VELOX_USER_FAIL(kErrorMessageEntryNotNull); } @@ -250,13 +248,11 @@ class MapFromEntriesFunction : public exec::VectorFunction { return mapVector; } - bool throwOnNull_; + const bool throwOnNull_; }; } // namespace -void registerMapFromEntriesFunction( - const std::string& name, - const bool throwOnNull) { +void registerMapFromEntriesFunction(const std::string& name, bool throwOnNull) { exec::registerVectorFunction( name, MapFromEntriesFunction::signatures(), diff --git a/velox/functions/lib/MapFromEntries.h b/velox/functions/lib/MapFromEntries.h index f247b4cf450f3..98d8aa5f92585 100644 --- a/velox/functions/lib/MapFromEntries.h +++ b/velox/functions/lib/MapFromEntries.h @@ -20,8 +20,8 @@ namespace facebook::velox::functions { -void registerMapFromEntriesFunction( - const std::string& name, - const bool throwForNull); +/// @param throwOnNull If true, throws exception when input array is null or +/// contains null entry. Otherwise, returns null. +void registerMapFromEntriesFunction(const std::string& name, bool throwForNull); } // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/tests/MapFromEntriesTest.cpp b/velox/functions/sparksql/tests/MapFromEntriesTest.cpp index 9c8226f8e9025..80cc4c30f94c0 100644 --- a/velox/functions/sparksql/tests/MapFromEntriesTest.cpp +++ b/velox/functions/sparksql/tests/MapFromEntriesTest.cpp @@ -49,11 +49,10 @@ class MapFromEntriesTest : public SparkFunctionBaseTest { } void verifyMapFromEntries( - const std::vector& input, - const VectorPtr& expected, - const std::string& funcArg = "c0") { - const std::string expr = fmt::format("map_from_entries({})", funcArg); - auto result = evaluate(expr, makeRowVector(input)); + const VectorPtr& input, + const VectorPtr& expected) { + const std::string expr = fmt::format("map_from_entries({})", "c0"); + auto result = evaluate(expr, makeRowVector({input})); assertEqualVectors(expected, result); } }; @@ -70,7 +69,7 @@ TEST_F(MapFromEntriesTest, nullMapEntries) { auto input = makeArrayOfRowVector(data, rowType); auto expected = makeNullableMapVector({std::nullopt, O({{1, 11}})}); - verifyMapFromEntries({input}, expected, "c0"); + verifyMapFromEntries(input, expected); } { // Create array(row(a,b)) where a, b sizes are 0 because all row(a, b) @@ -87,7 +86,7 @@ TEST_F(MapFromEntriesTest, nullMapEntries) { auto expected = makeNullableMapVector({std::nullopt, std::nullopt}); - verifyMapFromEntries({input}, expected, "c0"); + verifyMapFromEntries(input, expected); } }