Skip to content

Commit

Permalink
use flag to control map_from_entries behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
yma11 committed Mar 2, 2024
1 parent e0b1941 commit 75799dc
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 26 deletions.
6 changes: 4 additions & 2 deletions velox/docs/functions/spark/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ Map Functions

SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); -- {1.0 -> 2, 3.0 -> 4}

.. spark:function:: map_from_entries(struct(K,V)) -> map(K,V)
.. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V)
Converts an array of entries (key value struct types) to a map of values. All elements in keys should not be null.
If null entry exists in the array, return null for this whole array.::

SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))); -- {1 -> 'a', 2 -> 'b'}
SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'}
SELECT map_from_entries(array(struct(1, 'a'), null)); -- {1 -> 'a', 2 -> 'null'}


.. spark:function:: size(map(K,V)) -> bigint
:noindex:
Expand Down
1 change: 1 addition & 0 deletions velox/functions/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_library(
DateTimeFormatterBuilder.cpp
KllSketch.cpp
MapConcat.cpp
MapFromEntries.cpp
Re2Functions.cpp
Repeat.cpp
StringEncodingUtils.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ static const char* kIndeterminateKeyErrorMessage =
"map key cannot be indeterminate";
static const char* kErrorMessageEntryNotNull = "map entry cannot be null";

// See documentation at https://prestodb.io/docs/current/functions/map.html
class MapFromEntriesFunction : public exec::VectorFunction {
public:
// @param allowNullEle If true, will return null if map has
// an entry with null as key or map is null (Spark's behavior)
// instead of throw execeptions(Presto's behavior)
explicit MapFromEntriesFunction(bool allowNullEle)
: allowNullEle_(allowNullEle) {}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
Expand Down Expand Up @@ -94,14 +99,14 @@ class MapFromEntriesFunction : public exec::VectorFunction {
auto& inputValueVector = inputArray->elements();
exec::LocalDecodedVector decodedRowVector(context);
decodedRowVector.get()->decode(*inputValueVector);
// If the input array(unknown) then all rows should have errors.
if (inputValueVector->typeKind() == TypeKind::UNKNOWN) {
try {
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
} catch (...) {
context.setErrors(rows, std::current_exception());
if (!allowNullEle_) {
try {
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
} catch (...) {
context.setErrors(rows, std::current_exception());
}
}

auto sizes = allocateSizes(rows.end(), context.pool());
auto offsets = allocateSizes(rows.end(), context.pool());

Expand Down Expand Up @@ -129,8 +134,9 @@ class MapFromEntriesFunction : public exec::VectorFunction {
});

auto resetSize = [&](vector_size_t row) { mutableSizes[row] = 0; };
auto nulls = allocateNulls(decodedRowVector->size(), context.pool());
auto* mutableNulls = nulls->asMutable<uint64_t>();

// Validate all map entries and map keys are not null.
if (decodedRowVector->mayHaveNulls() || keyVector->mayHaveNulls() ||
keyVector->mayHaveNullsRecursive()) {
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
Expand All @@ -140,9 +146,13 @@ class MapFromEntriesFunction : public exec::VectorFunction {
for (auto i = 0; i < size; ++i) {
// Check nulls in the top level row vector.
const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i);
if (isMapEntryNull) {
// Set the sizes to 0 so that the final map vector generated is
// valid in case we are inside a try. The map vector needs to be
if (isMapEntryNull && allowNullEle_) {
// Spark: For nulls in the top level row vector, return null.
bits::setNull(mutableNulls, row);
break;
} else if (isMapEntryNull) {
// Presto: Set the sizes to 0 so that the final map vector generated
// is valid in case we are inside a try. The map vector needs to be
// valid because its consumed by checkDuplicateKeys before try
// sets invalid rows to null.
resetSize(row);
Expand Down Expand Up @@ -200,8 +210,6 @@ class MapFromEntriesFunction : public exec::VectorFunction {
} else {
// Dictionary.
auto indices = allocateIndices(decodedRowVector->size(), context.pool());
auto nulls = allocateNulls(decodedRowVector->size(), context.pool());
auto* mutableNulls = nulls->asMutable<uint64_t>();
memcpy(
indices->asMutable<vector_size_t>(),
decodedRowVector->indices(),
Expand All @@ -224,7 +232,7 @@ class MapFromEntriesFunction : public exec::VectorFunction {
auto mapVector = std::make_shared<MapVector>(
context.pool(),
outputType,
inputArray->nulls(),
nulls,
rows.end(),
inputArray->offsets(),
sizes,
Expand All @@ -234,11 +242,21 @@ class MapFromEntriesFunction : public exec::VectorFunction {
checkDuplicateKeys(mapVector, *remianingRows, context);
return mapVector;
}
const bool allowNullEle_;
};
} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
udf_map_from_entries,
MapFromEntriesFunction::signatures(),
std::make_unique<MapFromEntriesFunction>());
void registerMapFromEntriesFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapFromEntriesFunction::signatures(),
std::make_unique<MapFromEntriesFunction>(/*AllowNullEle=*/false));
}

void registerMapFromEntriesAllowNullEleFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapFromEntriesFunction::signatures(),
std::make_unique<MapFromEntriesFunction>(/*AllowNullEle=*/true));
}
} // namespace facebook::velox::functions
27 changes: 27 additions & 0 deletions velox/functions/lib/MapFromEntries.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <string>

namespace facebook::velox::functions {

void registerMapFromEntriesFunction(const std::string& name);

void registerMapFromEntriesAllowNullEleFunction(const std::string& name);

} // namespace facebook::velox::functions
1 change: 0 additions & 1 deletion velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ add_library(
JsonFunctions.cpp
Map.cpp
MapEntries.cpp
MapFromEntries.cpp
MapKeysAndValues.cpp
MapZipWith.cpp
Not.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/MapConcat.h"
#include "velox/functions/lib/MapFromEntries.h"
#include "velox/functions/prestosql/MultimapFromEntries.h"

namespace facebook::velox::functions {
Expand All @@ -27,8 +28,8 @@ void registerMapFunctions(const std::string& prefix) {
udf_transform_values, prefix + "transform_values");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_from_entries, prefix + "map_from_entries");
registerMapFromEntriesFunction(prefix + "map_from_entries");

VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_zip_with, prefix + "map_zip_with");
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/MapFromEntriesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <optional>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/lib/CheckDuplicateKeys.h"
// #include "velox/functions/lib/CheckDuplicateKeys.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/vector/tests/TestingDictionaryArrayElementsFunction.h"

Expand Down
1 change: 0 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ add_library(
In.cpp
LeastGreatest.cpp
Map.cpp
MapFromEntries.cpp
RegexFunctions.cpp
Register.cpp
RegisterArithmetic.cpp
Expand Down
4 changes: 2 additions & 2 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/expression/RowConstructor.h"
#include "velox/expression/SpecialFormRegistry.h"
#include "velox/functions/lib/IsNull.h"
#include "velox/functions/lib/MapFromEntries.h"
#include "velox/functions/lib/Re2Functions.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/lib/Repeat.h"
Expand Down Expand Up @@ -69,8 +70,7 @@ static void workAroundRegistrationMacro(const std::string& prefix) {

VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_allow_duplicates, prefix + "map_from_arrays");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_from_entries, prefix + "map_from_entries");
registerMapFromEntriesAllowNullEleFunction(prefix + "map_from_entries");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor);
// String functions.
Expand Down

0 comments on commit 75799dc

Please sign in to comment.