diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index 705dde524c4e..aa5932e16732 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -106,6 +106,15 @@ Array Functions SELECT filter(array(0, 2, 3), (x, i) -> x > i); -- [2, 3] SELECT filter(array(0, null, 2, 3, null), x -> x IS NOT NULL); -- [0, 2, 3] +.. function:: flatten(array(array(E))) -> array(E) + + Transforms an array of arrays into a single array. + Returns NULL if the input is NULL or any of the nested arrays is NULL. :: + + SELECT flatten(array(array(1, 2), array(3, 4))); -- [1, 2, 3, 4] + SELECT flatten(array(array(1, 2), array(3, NULL))); -- [1, 2, 3, NULL] + SELECT flatten(array(array(1, 2), NULL, array(3, 4))); -- NULL + .. spark:function:: in(value, array(E)) -> boolean Returns true if value matches at least one of the elements of the array. diff --git a/velox/functions/sparksql/ArrayFlattenFunction.h b/velox/functions/sparksql/ArrayFlattenFunction.h new file mode 100644 index 000000000000..ac369714de25 --- /dev/null +++ b/velox/functions/sparksql/ArrayFlattenFunction.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/functions/Udf.h" + +namespace facebook::velox::functions::sparksql { + +/// flatten(array(array(E))) → array(E) +/// Flattens nested array by concatenating the contained arrays. +template +struct ArrayFlattenFunction { + VELOX_DEFINE_FUNCTION_TYPES(T) + + // INT_MAX - 15, keep the same limit with spark. + static constexpr int32_t kMaxNumberOfElements = 2147483632; + + FOLLY_ALWAYS_INLINE bool call( + out_type>>& out, + const arg_type>>>& arrays) { + int64_t elementCount = 0; + for (const auto& array : arrays) { + if (array.has_value()) { + elementCount += array.value().size(); + } else { + // Return NULL if any of the nested arrays is NULL. + return false; + } + } + + VELOX_USER_CHECK_LE( + elementCount, + kMaxNumberOfElements, + "array flatten result exceeds the max array size limit {}", + kMaxNumberOfElements); + + out.reserve(elementCount); + for (const auto& array : arrays) { + out.add_items(array.value()); + }; + return true; + } +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 8b8f93a8d373..81f42fbced02 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -25,6 +25,7 @@ #include "velox/functions/prestosql/ArrayFunctions.h" #include "velox/functions/prestosql/DateTimeFunctions.h" #include "velox/functions/prestosql/StringFunctions.h" +#include "velox/functions/sparksql/ArrayFlattenFunction.h" #include "velox/functions/sparksql/ArrayMinMaxFunction.h" #include "velox/functions/sparksql/ArraySizeFunction.h" #include "velox/functions/sparksql/ArraySort.h" @@ -407,6 +408,11 @@ void registerFunctions(const std::string& prefix) { {prefix + "monotonically_increasing_id"}); registerFunction>({prefix + "uuid"}); + + registerFunction< + ArrayFlattenFunction, + Array>, + Array>>>({prefix + "flatten"}); } } // namespace sparksql diff --git a/velox/functions/sparksql/tests/ArrayFlattenTest.cpp b/velox/functions/sparksql/tests/ArrayFlattenTest.cpp new file mode 100644 index 000000000000..b531c0b903c7 --- /dev/null +++ b/velox/functions/sparksql/tests/ArrayFlattenTest.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +using namespace facebook::velox::test; + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class ArrayFlattenTest : public SparkFunctionBaseTest { + protected: + void testExpression( + const std::string& expression, + const std::vector& input, + const VectorPtr& expected) { + const auto result = evaluate(expression, makeRowVector(input)); + assertEqualVectors(expected, result); + } +}; + +// Flatten integer arrays. +TEST_F(ArrayFlattenTest, intArrays) { + const auto arrayOfArrays = makeNestedArrayVectorFromJson({ + "[[1, 1], [2, 2], [3, 3]]", + "[[4, 4]]", + "[[5, 5], [6, 6]]", + }); + + const auto expected = makeArrayVectorFromJson( + {"[1, 1, 2, 2, 3, 3]", "[4, 4]", "[5, 5, 6, 6]"}); + + testExpression("flatten(c0)", {arrayOfArrays}, expected); +} + +// Flatten arrays with null. +TEST_F(ArrayFlattenTest, nullArray) { + const auto arrayOfArrays = makeNestedArrayVectorFromJson({ + "[[1, 1], null, [3, 3]]", + "null", + "[[5, null], [null, 6], [null, null], []]", + }); + + const auto expected = makeArrayVectorFromJson( + {"null", "null", "[5, null, null, 6, null, null]"}); + + testExpression("flatten(c0)", {arrayOfArrays}, expected); +} +} // namespace +} // namespace facebook::velox::functions::sparksql::test diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 0089957f10d1..a1b00031f6ab 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -15,6 +15,7 @@ add_executable( velox_functions_spark_test ArithmeticTest.cpp + ArrayFlattenTest.cpp ArrayMaxTest.cpp ArrayMinTest.cpp ArraySizeTest.cpp