From 75c7f1ad6d22323d9674c7408f90c45767a09a20 Mon Sep 17 00:00:00 2001 From: yan ma Date: Sun, 3 Mar 2024 01:06:05 +0800 Subject: [PATCH] use flag to control map_from_entries behavior --- velox/docs/functions/spark/map.rst | 6 +- velox/functions/lib/CMakeLists.txt | 1 + .../{prestosql => lib}/MapFromEntries.cpp | 57 +++-- velox/functions/lib/MapFromEntries.h | 27 ++ velox/functions/prestosql/CMakeLists.txt | 1 - .../registration/MapFunctionsRegistration.cpp | 5 +- .../prestosql/tests/MapFromEntriesTest.cpp | 1 + velox/functions/sparksql/CMakeLists.txt | 1 - velox/functions/sparksql/MapFromEntries.cpp | 230 ------------------ velox/functions/sparksql/Register.cpp | 4 +- 10 files changed, 75 insertions(+), 258 deletions(-) rename velox/functions/{prestosql => lib}/MapFromEntries.cpp (82%) create mode 100644 velox/functions/lib/MapFromEntries.h delete mode 100644 velox/functions/sparksql/MapFromEntries.cpp diff --git a/velox/docs/functions/spark/map.rst b/velox/docs/functions/spark/map.rst index 13842847271f..aceabe402f05 100644 --- a/velox/docs/functions/spark/map.rst +++ b/velox/docs/functions/spark/map.rst @@ -27,12 +27,14 @@ Map Functions SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); -- {1.0 -> 2, 3.0 -> 4} -.. spark:function:: map_from_entries(struct(K,V)) -> map(K,V) +.. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V) Converts an array of entries (key value struct types) to a map of values. All elements in keys should not be null. If null entry exists in the array, return null for this whole array.:: - SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))); -- {1 -> 'a', 2 -> 'b'} + SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'} + SELECT map_from_entries(array(struct(1, 'a'), null)); -- {null} + .. spark:function:: size(map(K,V)) -> bigint :noindex: diff --git a/velox/functions/lib/CMakeLists.txt b/velox/functions/lib/CMakeLists.txt index fe1bdb405ff9..2a3a4391e699 100644 --- a/velox/functions/lib/CMakeLists.txt +++ b/velox/functions/lib/CMakeLists.txt @@ -27,6 +27,7 @@ add_library( DateTimeFormatterBuilder.cpp KllSketch.cpp MapConcat.cpp + MapFromEntries.cpp Re2Functions.cpp Repeat.cpp StringEncodingUtils.cpp diff --git a/velox/functions/prestosql/MapFromEntries.cpp b/velox/functions/lib/MapFromEntries.cpp similarity index 82% rename from velox/functions/prestosql/MapFromEntries.cpp rename to velox/functions/lib/MapFromEntries.cpp index c6700b0dbaf9..226304391053 100644 --- a/velox/functions/prestosql/MapFromEntries.cpp +++ b/velox/functions/lib/MapFromEntries.cpp @@ -30,7 +30,10 @@ static const char* kIndeterminateKeyErrorMessage = "map key cannot be indeterminate"; static const char* kErrorMessageEntryNotNull = "map entry cannot be null"; -// See documentation at https://prestodb.io/docs/current/functions/map.html +// allowNullEle If true, will return null if map has +// an entry with null as key or map is null (Spark's behavior) +// instead of throw execeptions(Presto's behavior) +template class MapFromEntriesFunction : public exec::VectorFunction { public: void apply( @@ -94,14 +97,14 @@ class MapFromEntriesFunction : public exec::VectorFunction { auto& inputValueVector = inputArray->elements(); exec::LocalDecodedVector decodedRowVector(context); decodedRowVector.get()->decode(*inputValueVector); - // If the input array(unknown) then all rows should have errors. if (inputValueVector->typeKind() == TypeKind::UNKNOWN) { - try { - VELOX_USER_FAIL(kErrorMessageEntryNotNull); - } catch (...) { - context.setErrors(rows, std::current_exception()); + if (!allowNullEle) { + try { + VELOX_USER_FAIL(kErrorMessageEntryNotNull); + } catch (...) { + context.setErrors(rows, std::current_exception()); + } } - auto sizes = allocateSizes(rows.end(), context.pool()); auto offsets = allocateSizes(rows.end(), context.pool()); @@ -129,8 +132,9 @@ class MapFromEntriesFunction : public exec::VectorFunction { }); auto resetSize = [&](vector_size_t row) { mutableSizes[row] = 0; }; + auto nulls = allocateNulls(decodedRowVector->size(), context.pool()); + auto* mutableNulls = nulls->asMutable(); - // Validate all map entries and map keys are not null. if (decodedRowVector->mayHaveNulls() || keyVector->mayHaveNulls() || keyVector->mayHaveNullsRecursive()) { context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { @@ -140,9 +144,14 @@ class MapFromEntriesFunction : public exec::VectorFunction { for (auto i = 0; i < size; ++i) { // Check nulls in the top level row vector. const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i); - if (isMapEntryNull) { - // Set the sizes to 0 so that the final map vector generated is - // valid in case we are inside a try. The map vector needs to be + if (isMapEntryNull && allowNullEle) { + // Spark: For nulls in the top level row vector, return null. + bits::setNull(mutableNulls, row); + resetSize(row); + break; + } else if (isMapEntryNull) { + // Presto: Set the sizes to 0 so that the final map vector generated + // is valid in case we are inside a try. The map vector needs to be // valid because its consumed by checkDuplicateKeys before try // sets invalid rows to null. resetSize(row); @@ -200,8 +209,6 @@ class MapFromEntriesFunction : public exec::VectorFunction { } else { // Dictionary. auto indices = allocateIndices(decodedRowVector->size(), context.pool()); - auto nulls = allocateNulls(decodedRowVector->size(), context.pool()); - auto* mutableNulls = nulls->asMutable(); memcpy( indices->asMutable(), decodedRowVector->indices(), @@ -219,12 +226,13 @@ class MapFromEntriesFunction : public exec::VectorFunction { nulls, indices, decodedRowVector->size(), rowVector->childAt(1)); } - // To avoid creating new buffers, we try to reuse the input's buffers - // as many as possible. + // For Presto, need construct map vector based on input nulls for possible + // outer expression like try(). For Spark, use the updated nulls. + auto mapVetorNulls = allowNullEle ? nulls : inputArray->nulls(); auto mapVector = std::make_shared( context.pool(), outputType, - inputArray->nulls(), + mapVetorNulls, rows.end(), inputArray->offsets(), sizes, @@ -237,8 +245,17 @@ class MapFromEntriesFunction : public exec::VectorFunction { }; } // namespace -VELOX_DECLARE_VECTOR_FUNCTION( - udf_map_from_entries, - MapFromEntriesFunction::signatures(), - std::make_unique()); +void registerMapFromEntriesFunction(const std::string& name) { + exec::registerVectorFunction( + name, + MapFromEntriesFunction::signatures(), + std::make_unique>()); +} + +void registerMapFromEntriesAllowNullEleFunction(const std::string& name) { + exec::registerVectorFunction( + name, + MapFromEntriesFunction::signatures(), + std::make_unique>()); +} } // namespace facebook::velox::functions diff --git a/velox/functions/lib/MapFromEntries.h b/velox/functions/lib/MapFromEntries.h new file mode 100644 index 000000000000..02860d62e4f0 --- /dev/null +++ b/velox/functions/lib/MapFromEntries.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::functions { + +void registerMapFromEntriesFunction(const std::string& name); + +void registerMapFromEntriesAllowNullEleFunction(const std::string& name); + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 8c17ce34ad0c..c41fc7772e8a 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -39,7 +39,6 @@ add_library( JsonFunctions.cpp Map.cpp MapEntries.cpp - MapFromEntries.cpp MapKeysAndValues.cpp MapZipWith.cpp Not.cpp diff --git a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp index fca2990a4aa7..6c811eafb1e0 100644 --- a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp @@ -17,6 +17,7 @@ #include "velox/expression/VectorFunction.h" #include "velox/functions/Registerer.h" #include "velox/functions/lib/MapConcat.h" +#include "velox/functions/lib/MapFromEntries.h" #include "velox/functions/prestosql/MapNormalize.h" #include "velox/functions/prestosql/MapSubset.h" #include "velox/functions/prestosql/MapTopN.h" @@ -41,8 +42,8 @@ void registerMapFunctions(const std::string& prefix) { udf_transform_values, prefix + "transform_values"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries"); - VELOX_REGISTER_VECTOR_FUNCTION( - udf_map_from_entries, prefix + "map_from_entries"); + registerMapFromEntriesFunction(prefix + "map_from_entries"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_zip_with, prefix + "map_zip_with"); diff --git a/velox/functions/prestosql/tests/MapFromEntriesTest.cpp b/velox/functions/prestosql/tests/MapFromEntriesTest.cpp index adf14fd87b58..1a1fdc2d0f36 100644 --- a/velox/functions/prestosql/tests/MapFromEntriesTest.cpp +++ b/velox/functions/prestosql/tests/MapFromEntriesTest.cpp @@ -17,6 +17,7 @@ #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/lib/CheckDuplicateKeys.h" +// #include "velox/functions/lib/CheckDuplicateKeys.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/vector/tests/TestingDictionaryArrayElementsFunction.h" diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index 941fa9a53c7e..6ba9c03c6c79 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -25,7 +25,6 @@ add_library( LeastGreatest.cpp MakeTimestamp.cpp Map.cpp - MapFromEntries.cpp RegexFunctions.cpp Register.cpp RegisterArithmetic.cpp diff --git a/velox/functions/sparksql/MapFromEntries.cpp b/velox/functions/sparksql/MapFromEntries.cpp deleted file mode 100644 index ee7f0075c6fb..000000000000 --- a/velox/functions/sparksql/MapFromEntries.cpp +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "velox/expression/EvalCtx.h" -#include "velox/expression/Expr.h" -#include "velox/expression/VectorFunction.h" -#include "velox/functions/lib/CheckDuplicateKeys.h" -#include "velox/functions/lib/RowsTranslationUtil.h" -#include "velox/vector/BaseVector.h" -#include "velox/vector/ComplexVector.h" - -namespace facebook::velox::functions { -namespace { -static const char* kNullKeyErrorMessage = "map key cannot be null"; -static const char* kIndeterminateKeyErrorMessage = - "map key cannot be indeterminate"; - -class MapFromEntriesFunction : public exec::VectorFunction { - public: - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& outputType, - exec::EvalCtx& context, - VectorPtr& result) const override { - VELOX_CHECK_EQ(args.size(), 1); - auto& arg = args[0]; - VectorPtr localResult; - // Input can be constant or flat. - if (arg->isConstantEncoding()) { - auto* constantArray = arg->as>(); - const auto& flatArray = constantArray->valueVector(); - const auto flatIndex = constantArray->index(); - - exec::LocalSelectivityVector singleRow(context, flatIndex + 1); - singleRow->clearAll(); - singleRow->setValid(flatIndex, true); - singleRow->updateBounds(); - - localResult = applyFlat( - *singleRow.get(), flatArray->as(), outputType, context); - localResult = - BaseVector::wrapInConstant(rows.size(), flatIndex, localResult); - } else { - localResult = - applyFlat(rows, arg->as(), outputType, context); - } - - context.moveOrCopyResult(localResult, rows, result); - } - - static std::vector> signatures() { - return {// unknown -> map(unknown, unknown) - exec::FunctionSignatureBuilder() - .returnType("map(unknown, unknown)") - .argumentType("unknown") - .build(), - // array(unknown) -> map(unknown, unknown) - exec::FunctionSignatureBuilder() - .returnType("map(unknown, unknown)") - .argumentType("array(unknown)") - .build(), - // array(row(K,V)) -> map(K,V) - exec::FunctionSignatureBuilder() - .typeVariable("K") - .typeVariable("V") - .returnType("map(K,V)") - .argumentType("array(row(K,V))") - .build()}; - } - - private: - VectorPtr applyFlat( - const SelectivityVector& rows, - const ArrayVector* inputArray, - const TypePtr& outputType, - exec::EvalCtx& context) const { - auto& inputValueVector = inputArray->elements(); - exec::LocalDecodedVector decodedRowVector(context); - decodedRowVector.get()->decode(*inputValueVector); - if (inputValueVector->typeKind() == TypeKind::UNKNOWN) { - auto sizes = allocateSizes(rows.end(), context.pool()); - auto offsets = allocateSizes(rows.end(), context.pool()); - - // Output in this case is map(unknown, unknown), but all elements are - // nulls, all offsets and sizes are 0. - return std::make_shared( - context.pool(), - outputType, - inputArray->nulls(), - rows.end(), - sizes, - offsets, - BaseVector::create(UNKNOWN(), 0, context.pool()), - BaseVector::create(UNKNOWN(), 0, context.pool())); - } - - exec::LocalSelectivityVector remianingRows(context, rows); - auto rowVector = decodedRowVector->base()->as(); - auto keyVector = rowVector->childAt(0); - - BufferPtr sizes = allocateSizes(rows.end(), context.pool()); - vector_size_t* mutableSizes = sizes->asMutable(); - rows.applyToSelected([&](vector_size_t row) { - mutableSizes[row] = inputArray->rawSizes()[row]; - }); - - auto resetSize = [&](vector_size_t row) { mutableSizes[row] = 0; }; - auto nulls = allocateNulls(decodedRowVector->size(), context.pool()); - auto* mutableNulls = nulls->asMutable(); - - if (decodedRowVector->mayHaveNulls() || keyVector->mayHaveNulls() || - keyVector->mayHaveNullsRecursive()) { - context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - const auto size = inputArray->sizeAt(row); - const auto offset = inputArray->offsetAt(row); - - for (auto i = 0; i < size; ++i) { - // For nulls in the top level row vector, return null. - const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i); - if (isMapEntryNull) { - bits::setNull(mutableNulls, row); - break; - } - - // Check null keys. - auto keyIndex = decodedRowVector->index(offset + i); - if (keyVector->isNullAt(keyIndex)) { - resetSize(row); - VELOX_USER_FAIL(kNullKeyErrorMessage); - } - - // Check nested null in keys. - if (keyVector->containsNullAt(keyIndex)) { - resetSize(row); - VELOX_USER_FAIL(fmt::format( - "{}: {}", - kIndeterminateKeyErrorMessage, - keyVector->toString(keyIndex))); - } - } - }); - } - - context.deselectErrors(*remianingRows.get()); - - VectorPtr wrappedKeys; - VectorPtr wrappedValues; - if (decodedRowVector->isIdentityMapping()) { - wrappedKeys = rowVector->childAt(0); - wrappedValues = rowVector->childAt(1); - } else if (decodedRowVector->isConstantMapping()) { - if (decodedRowVector->isNullAt(0)) { - // If top level row is null, child might not be addressable at index 0 - // so we do not try to read it. - wrappedKeys = BaseVector::createNullConstant( - rowVector->childAt(0)->type(), - decodedRowVector->size(), - context.pool()); - wrappedValues = BaseVector::createNullConstant( - rowVector->childAt(1)->type(), - decodedRowVector->size(), - context.pool()); - } else { - wrappedKeys = BaseVector::wrapInConstant( - decodedRowVector->size(), - decodedRowVector->index(0), - rowVector->childAt(0)); - wrappedValues = BaseVector::wrapInConstant( - decodedRowVector->size(), - decodedRowVector->index(0), - rowVector->childAt(1)); - } - } else { - // Dictionary. - auto indices = allocateIndices(decodedRowVector->size(), context.pool()); - memcpy( - indices->asMutable(), - decodedRowVector->indices(), - BaseVector::byteSize(decodedRowVector->size())); - // Any null in the top row(X, Y) should be marked as null since its - // not guranteed to be addressable at X or Y. - for (auto i = 0; i < decodedRowVector->size(); i++) { - if (decodedRowVector->isNullAt(i)) { - bits::setNull(mutableNulls, i); - } - } - wrappedKeys = BaseVector::wrapInDictionary( - nulls, indices, decodedRowVector->size(), rowVector->childAt(0)); - wrappedValues = BaseVector::wrapInDictionary( - nulls, indices, decodedRowVector->size(), rowVector->childAt(1)); - } - - // To avoid creating new buffers, we try to reuse the input's buffers - // as many as possible. - auto mapVector = std::make_shared( - context.pool(), - outputType, - nulls, - rows.end(), - inputArray->offsets(), - sizes, - wrappedKeys, - wrappedValues); - - checkDuplicateKeys(mapVector, *remianingRows, context); - return mapVector; - } -}; -} // namespace - -VELOX_DECLARE_VECTOR_FUNCTION( - udf_map_from_entries, - MapFromEntriesFunction::signatures(), - std::make_unique()); -} // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index c8e13cf08911..0b5772905c99 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -19,6 +19,7 @@ #include "velox/expression/RowConstructor.h" #include "velox/expression/SpecialFormRegistry.h" #include "velox/functions/lib/IsNull.h" +#include "velox/functions/lib/MapFromEntries.h" #include "velox/functions/lib/Re2Functions.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/lib/Repeat.h" @@ -72,8 +73,7 @@ static void workAroundRegistrationMacro(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION( udf_map_allow_duplicates, prefix + "map_from_arrays"); - VELOX_REGISTER_VECTOR_FUNCTION( - udf_map_from_entries, prefix + "map_from_entries"); + registerMapFromEntriesAllowNullEleFunction(prefix + "map_from_entries"); VELOX_REGISTER_VECTOR_FUNCTION( udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor); // String functions.