Skip to content

Commit

Permalink
Add map_from_entries Spark function
Browse files Browse the repository at this point in the history
  • Loading branch information
yma11 committed Feb 20, 2024
1 parent d55c48b commit 66831a8
Show file tree
Hide file tree
Showing 6 changed files with 700 additions and 0 deletions.
7 changes: 7 additions & 0 deletions velox/docs/functions/spark/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ Map Functions

SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); -- {1.0 -> 2, 3.0 -> 4}

.. spark:function:: map_from_entries(struct(K,V)) -> map(K,V)
Converts an array of entries (key value struct types) to a map of values. All elements in keys should not be null.
If null entry exists in the array, return null for this whole array.::

SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))); -- {1 -> 'a', 2 -> 'b'}

.. spark:function:: size(map(K,V)) -> bigint
:noindex:

Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_library(
In.cpp
LeastGreatest.cpp
Map.cpp
MapFromEntries.cpp
RegexFunctions.cpp
Register.cpp
RegisterArithmetic.cpp
Expand Down
230 changes: 230 additions & 0 deletions velox/functions/sparksql/MapFromEntries.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>

#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/CheckDuplicateKeys.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/vector/BaseVector.h"
#include "velox/vector/ComplexVector.h"

namespace facebook::velox::functions {
namespace {
static const char* kNullKeyErrorMessage = "map key cannot be null";
static const char* kIndeterminateKeyErrorMessage =
"map key cannot be indeterminate";

class MapFromEntriesFunction : public exec::VectorFunction {
public:
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_EQ(args.size(), 1);
auto& arg = args[0];
VectorPtr localResult;
// Input can be constant or flat.
if (arg->isConstantEncoding()) {
auto* constantArray = arg->as<ConstantVector<ComplexType>>();
const auto& flatArray = constantArray->valueVector();
const auto flatIndex = constantArray->index();

exec::LocalSelectivityVector singleRow(context, flatIndex + 1);
singleRow->clearAll();
singleRow->setValid(flatIndex, true);
singleRow->updateBounds();

localResult = applyFlat(
*singleRow.get(), flatArray->as<ArrayVector>(), outputType, context);
localResult =
BaseVector::wrapInConstant(rows.size(), flatIndex, localResult);
} else {
localResult =
applyFlat(rows, arg->as<ArrayVector>(), outputType, context);
}

context.moveOrCopyResult(localResult, rows, result);
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
return {// unknown -> map(unknown, unknown)
exec::FunctionSignatureBuilder()
.returnType("map(unknown, unknown)")
.argumentType("unknown")
.build(),
// array(unknown) -> map(unknown, unknown)
exec::FunctionSignatureBuilder()
.returnType("map(unknown, unknown)")
.argumentType("array(unknown)")
.build(),
// array(row(K,V)) -> map(K,V)
exec::FunctionSignatureBuilder()
.typeVariable("K")
.typeVariable("V")
.returnType("map(K,V)")
.argumentType("array(row(K,V))")
.build()};
}

private:
VectorPtr applyFlat(
const SelectivityVector& rows,
const ArrayVector* inputArray,
const TypePtr& outputType,
exec::EvalCtx& context) const {
auto& inputValueVector = inputArray->elements();
exec::LocalDecodedVector decodedRowVector(context);
decodedRowVector.get()->decode(*inputValueVector);
if (inputValueVector->typeKind() == TypeKind::UNKNOWN) {
auto sizes = allocateSizes(rows.end(), context.pool());
auto offsets = allocateSizes(rows.end(), context.pool());

// Output in this case is map(unknown, unknown), but all elements are
// nulls, all offsets and sizes are 0.
return std::make_shared<MapVector>(
context.pool(),
outputType,
inputArray->nulls(),
rows.end(),
sizes,
offsets,
BaseVector::create(UNKNOWN(), 0, context.pool()),
BaseVector::create(UNKNOWN(), 0, context.pool()));
}

exec::LocalSelectivityVector remianingRows(context, rows);
auto rowVector = decodedRowVector->base()->as<RowVector>();
auto keyVector = rowVector->childAt(0);

BufferPtr sizes = allocateSizes(rows.end(), context.pool());
vector_size_t* mutableSizes = sizes->asMutable<vector_size_t>();
rows.applyToSelected([&](vector_size_t row) {
mutableSizes[row] = inputArray->rawSizes()[row];
});

auto resetSize = [&](vector_size_t row) { mutableSizes[row] = 0; };
auto nulls = allocateNulls(decodedRowVector->size(), context.pool());
auto* mutableNulls = nulls->asMutable<uint64_t>();

if (decodedRowVector->mayHaveNulls() || keyVector->mayHaveNulls() ||
keyVector->mayHaveNullsRecursive()) {
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
const auto size = inputArray->sizeAt(row);
const auto offset = inputArray->offsetAt(row);

for (auto i = 0; i < size; ++i) {
// For nulls in the top level row vector, return null.
const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i);
if (isMapEntryNull) {
bits::setNull(mutableNulls, row);
break;
}

// Check null keys.
auto keyIndex = decodedRowVector->index(offset + i);
if (keyVector->isNullAt(keyIndex)) {
resetSize(row);
VELOX_USER_FAIL(kNullKeyErrorMessage);
}

// Check nested null in keys.
if (keyVector->containsNullAt(keyIndex)) {
resetSize(row);
VELOX_USER_FAIL(fmt::format(
"{}: {}",
kIndeterminateKeyErrorMessage,
keyVector->toString(keyIndex)));
}
}
});
}

context.deselectErrors(*remianingRows.get());

VectorPtr wrappedKeys;
VectorPtr wrappedValues;
if (decodedRowVector->isIdentityMapping()) {
wrappedKeys = rowVector->childAt(0);
wrappedValues = rowVector->childAt(1);
} else if (decodedRowVector->isConstantMapping()) {
if (decodedRowVector->isNullAt(0)) {
// If top level row is null, child might not be addressable at index 0
// so we do not try to read it.
wrappedKeys = BaseVector::createNullConstant(
rowVector->childAt(0)->type(),
decodedRowVector->size(),
context.pool());
wrappedValues = BaseVector::createNullConstant(
rowVector->childAt(1)->type(),
decodedRowVector->size(),
context.pool());
} else {
wrappedKeys = BaseVector::wrapInConstant(
decodedRowVector->size(),
decodedRowVector->index(0),
rowVector->childAt(0));
wrappedValues = BaseVector::wrapInConstant(
decodedRowVector->size(),
decodedRowVector->index(0),
rowVector->childAt(1));
}
} else {
// Dictionary.
auto indices = allocateIndices(decodedRowVector->size(), context.pool());
memcpy(
indices->asMutable<vector_size_t>(),
decodedRowVector->indices(),
BaseVector::byteSize<vector_size_t>(decodedRowVector->size()));
// Any null in the top row(X, Y) should be marked as null since its
// not guranteed to be addressable at X or Y.
for (auto i = 0; i < decodedRowVector->size(); i++) {
if (decodedRowVector->isNullAt(i)) {
bits::setNull(mutableNulls, i);
}
}
wrappedKeys = BaseVector::wrapInDictionary(
nulls, indices, decodedRowVector->size(), rowVector->childAt(0));
wrappedValues = BaseVector::wrapInDictionary(
nulls, indices, decodedRowVector->size(), rowVector->childAt(1));
}

// To avoid creating new buffers, we try to reuse the input's buffers
// as many as possible.
auto mapVector = std::make_shared<MapVector>(
context.pool(),
outputType,
nulls,
rows.end(),
inputArray->offsets(),
sizes,
wrappedKeys,
wrappedValues);

checkDuplicateKeys(mapVector, *remianingRows, context);
return mapVector;
}
};
} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
udf_map_from_entries,
MapFromEntriesFunction::signatures(),
std::make_unique<MapFromEntriesFunction>());
} // namespace facebook::velox::functions
2 changes: 2 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ static void workAroundRegistrationMacro(const std::string& prefix) {

VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_allow_duplicates, prefix + "map_from_arrays");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_from_entries, prefix + "map_from_entries");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor);
// String functions.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_executable(
InTest.cpp
LeastGreatestTest.cpp
MapTest.cpp
MapFromEntriesTest.cpp
MightContainTest.cpp
RandTest.cpp
RegexFunctionsTest.cpp
Expand Down
Loading

0 comments on commit 66831a8

Please sign in to comment.