diff --git a/velox/docs/functions/spark/map.rst b/velox/docs/functions/spark/map.rst index 48a71952d45c2..175c19298766c 100644 --- a/velox/docs/functions/spark/map.rst +++ b/velox/docs/functions/spark/map.rst @@ -43,7 +43,7 @@ Map Functions .. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V) - Returns a map created from the given array of entries. Keys are not allowed to be null or to contain nulls. + Returns a map created from the given array of entries. Exceptions will be thrown if key is null or contains null. If null entry exists in the array, return null for this whole array.:: SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'} diff --git a/velox/functions/lib/MapFromEntries.cpp b/velox/functions/lib/MapFromEntries.cpp index 29e06eab942f3..445aaa10f8c17 100644 --- a/velox/functions/lib/MapFromEntries.cpp +++ b/velox/functions/lib/MapFromEntries.cpp @@ -132,7 +132,7 @@ class MapFromEntriesFunction : public exec::VectorFunction { }); auto resetSize = [&](vector_size_t row) { mutableSizes[row] = 0; }; - auto nulls = allocateNulls(rows.size(), context.pool()); + auto nulls = allocateNulls(decodedRowVector->size(), context.pool()); auto* mutableNulls = nulls->asMutable(); if (decodedRowVector->mayHaveNulls() || keyVector->mayHaveNulls() || @@ -227,9 +227,13 @@ class MapFromEntriesFunction : public exec::VectorFunction { } // For Presto, need construct map vector based on input nulls for possible - // outer expression like try(). For Spark, use the updated nulls. + // outer expression like try(). For Spark, use the updated nulls unless it's empty. if constexpr (throwForNull) { nulls = inputArray->nulls(); + } else { + if (decodedRowVector->size() == 0) { + nulls = inputArray->nulls(); + } } auto mapVector = std::make_shared( context.pool(), diff --git a/velox/functions/sparksql/tests/MapFromEntriesTest.cpp b/velox/functions/sparksql/tests/MapFromEntriesTest.cpp index 578a1a06bb2b9..9c8226f8e9025 100644 --- a/velox/functions/sparksql/tests/MapFromEntriesTest.cpp +++ b/velox/functions/sparksql/tests/MapFromEntriesTest.cpp @@ -14,12 +14,9 @@ * limitations under the License. */ #include -#include #include "velox/common/base/tests/GTestUtils.h" -#include "velox/functions/lib/CheckDuplicateKeys.h" #include "velox/functions/prestosql/ArrayConstructor.h" #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" -#include "velox/vector/tests/TestingDictionaryArrayElementsFunction.h" using namespace facebook::velox::test;