Skip to content

Commit

Permalink
use flag to control map_from_entries behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
yma11 committed Apr 1, 2024
1 parent 3b7fd2f commit 73328d6
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 258 deletions.
6 changes: 4 additions & 2 deletions velox/docs/functions/spark/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ Map Functions

SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); -- {1.0 -> 2, 3.0 -> 4}

.. spark:function:: map_from_entries(struct(K,V)) -> map(K,V)
.. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V)
Converts an array of entries (key-value structs) to a map. Keys must not be null.
If the array contains a null entry, returns null for the whole array.::

SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))); -- {1 -> 'a', 2 -> 'b'}
SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'}
SELECT map_from_entries(array(struct(1, 'a'), null)); -- NULL


.. spark:function:: size(map(K,V)) -> bigint
:noindex:
Expand Down
1 change: 1 addition & 0 deletions velox/functions/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_library(
DateTimeFormatterBuilder.cpp
KllSketch.cpp
MapConcat.cpp
MapFromEntries.cpp
Re2Functions.cpp
Repeat.cpp
StringEncodingUtils.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ static const char* kIndeterminateKeyErrorMessage =
"map key cannot be indeterminate";
static const char* kErrorMessageEntryNotNull = "map entry cannot be null";

// See documentation at https://prestodb.io/docs/current/functions/map.html
// allowNullEle: If true, returns null for a row whose input array contains a
// null entry (Spark's behavior) instead of throwing an exception (Presto's
// behavior).
template <bool allowNullEle>
class MapFromEntriesFunction : public exec::VectorFunction {
public:
void apply(
Expand Down Expand Up @@ -94,14 +97,14 @@ class MapFromEntriesFunction : public exec::VectorFunction {
auto& inputValueVector = inputArray->elements();
exec::LocalDecodedVector decodedRowVector(context);
decodedRowVector.get()->decode(*inputValueVector);
// If the input is array(unknown), all rows should have errors unless null
// elements are allowed (allowNullEle).
if (inputValueVector->typeKind() == TypeKind::UNKNOWN) {
try {
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
} catch (...) {
context.setErrors(rows, std::current_exception());
if (!allowNullEle) {
try {
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
} catch (...) {
context.setErrors(rows, std::current_exception());
}
}

auto sizes = allocateSizes(rows.end(), context.pool());
auto offsets = allocateSizes(rows.end(), context.pool());

Expand Down Expand Up @@ -129,8 +132,9 @@ class MapFromEntriesFunction : public exec::VectorFunction {
});

auto resetSize = [&](vector_size_t row) { mutableSizes[row] = 0; };
auto nulls = allocateNulls(decodedRowVector->size(), context.pool());
auto* mutableNulls = nulls->asMutable<uint64_t>();

// Validate all map entries and map keys are not null.
if (decodedRowVector->mayHaveNulls() || keyVector->mayHaveNulls() ||
keyVector->mayHaveNullsRecursive()) {
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
Expand All @@ -140,9 +144,14 @@ class MapFromEntriesFunction : public exec::VectorFunction {
for (auto i = 0; i < size; ++i) {
// Check nulls in the top level row vector.
const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i);
if (isMapEntryNull) {
// Set the sizes to 0 so that the final map vector generated is
// valid in case we are inside a try. The map vector needs to be
if (isMapEntryNull && allowNullEle) {
// Spark: For nulls in the top level row vector, return null.
bits::setNull(mutableNulls, row);
resetSize(row);
break;
} else if (isMapEntryNull) {
// Presto: Set the sizes to 0 so that the final map vector generated
// is valid in case we are inside a try. The map vector needs to be
// valid because its consumed by checkDuplicateKeys before try
// sets invalid rows to null.
resetSize(row);
Expand Down Expand Up @@ -200,8 +209,6 @@ class MapFromEntriesFunction : public exec::VectorFunction {
} else {
// Dictionary.
auto indices = allocateIndices(decodedRowVector->size(), context.pool());
auto nulls = allocateNulls(decodedRowVector->size(), context.pool());
auto* mutableNulls = nulls->asMutable<uint64_t>();
memcpy(
indices->asMutable<vector_size_t>(),
decodedRowVector->indices(),
Expand All @@ -219,12 +226,13 @@ class MapFromEntriesFunction : public exec::VectorFunction {
nulls, indices, decodedRowVector->size(), rowVector->childAt(1));
}

// To avoid creating new buffers, we try to reuse the input's buffers
// as much as possible.
// For Presto, the map vector must be constructed from the input nulls to
// support a possible outer expression such as try(). For Spark, use the
// updated nulls.
auto mapVetorNulls = allowNullEle ? nulls : inputArray->nulls();
auto mapVector = std::make_shared<MapVector>(
context.pool(),
outputType,
inputArray->nulls(),
mapVetorNulls,
rows.end(),
inputArray->offsets(),
sizes,
Expand All @@ -237,8 +245,17 @@ class MapFromEntriesFunction : public exec::VectorFunction {
};
} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
udf_map_from_entries,
MapFromEntriesFunction::signatures(),
std::make_unique<MapFromEntriesFunction>());
void registerMapFromEntriesFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapFromEntriesFunction</*AllowNullEle=*/false>::signatures(),
std::make_unique<MapFromEntriesFunction</*AllowNullEle=*/false>>());
}

void registerMapFromEntriesAllowNullEleFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapFromEntriesFunction</*AllowNullEle=*/true>::signatures(),
std::make_unique<MapFromEntriesFunction</*AllowNullEle=*/true>>());
}
} // namespace facebook::velox::functions
27 changes: 27 additions & 0 deletions velox/functions/lib/MapFromEntries.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <string>

namespace facebook::velox::functions {

/// Registers the map_from_entries vector function under the given name.
/// NOTE(review): presumably the variant in which a null array entry is an
/// error (Presto semantics); confirm against MapFromEntries.cpp.
/// @param name Name under which the function is registered.
void registerMapFromEntriesFunction(const std::string& name);

/// Registers the map_from_entries vector function under the given name.
/// NOTE(review): presumably the variant in which a null array entry makes
/// the whole result null (Spark semantics); confirm against
/// MapFromEntries.cpp.
/// @param name Name under which the function is registered.
void registerMapFromEntriesAllowNullEleFunction(const std::string& name);

} // namespace facebook::velox::functions
1 change: 0 additions & 1 deletion velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ add_library(
JsonFunctions.cpp
Map.cpp
MapEntries.cpp
MapFromEntries.cpp
MapKeysAndValues.cpp
MapZipWith.cpp
Not.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/MapConcat.h"
#include "velox/functions/lib/MapFromEntries.h"
#include "velox/functions/prestosql/MapNormalize.h"
#include "velox/functions/prestosql/MapSubset.h"
#include "velox/functions/prestosql/MapTopN.h"
Expand Down Expand Up @@ -68,8 +69,8 @@ void registerMapFunctions(const std::string& prefix) {
udf_transform_values, prefix + "transform_values");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_from_entries, prefix + "map_from_entries");
registerMapFromEntriesFunction(prefix + "map_from_entries");

VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_zip_with, prefix + "map_zip_with");
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/MapFromEntriesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <optional>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/lib/CheckDuplicateKeys.h"
// #include "velox/functions/lib/CheckDuplicateKeys.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/vector/tests/TestingDictionaryArrayElementsFunction.h"

Expand Down
1 change: 0 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ add_library(
LeastGreatest.cpp
MakeTimestamp.cpp
Map.cpp
MapFromEntries.cpp
RegexFunctions.cpp
Register.cpp
RegisterArithmetic.cpp
Expand Down
Loading

0 comments on commit 73328d6

Please sign in to comment.