Skip to content

Commit

Permalink
Add flatten Spark function (#9593)
Browse files Browse the repository at this point in the history
Summary:
In Presto, `flatten` ignores NULL array element in the input.
In Spark, `flatten` returns NULL if any array element of the input is NULL.

Spark 3.5 ref: [Flatten function](https://github.com/apache/spark/blob/v3.5.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L2796)

Pull Request resolved: #9593

Reviewed By: pedroerp

Differential Revision: D56580116

Pulled By: kagamiori

fbshipit-source-id: 535f8112e3d37921f46ee41a8a05401f76c092cf
  • Loading branch information
ivoson authored and facebook-github-bot committed Apr 26, 2024
1 parent 1daeb9d commit 32289f9
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 0 deletions.
9 changes: 9 additions & 0 deletions velox/docs/functions/spark/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ Array Functions
SELECT filter(array(0, 2, 3), (x, i) -> x > i); -- [2, 3]
SELECT filter(array(0, null, 2, 3, null), x -> x IS NOT NULL); -- [0, 2, 3]

.. function:: flatten(array(array(E))) -> array(E)

Transforms an array of arrays into a single array.
Returns NULL if the input is NULL or any of the nested arrays is NULL. ::

SELECT flatten(array(array(1, 2), array(3, 4))); -- [1, 2, 3, 4]
SELECT flatten(array(array(1, 2), array(3, NULL))); -- [1, 2, 3, NULL]
SELECT flatten(array(array(1, 2), NULL, array(3, 4))); -- NULL

.. spark:function:: in(value, array(E)) -> boolean
Returns true if value matches at least one of the elements of the array.
Expand Down
57 changes: 57 additions & 0 deletions velox/functions/sparksql/ArrayFlattenFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/functions/Udf.h"

namespace facebook::velox::functions::sparksql {

/// flatten(array(array(E))) → array(E)
/// Flattens nested array by concatenating the contained arrays.
template <typename T>
struct ArrayFlattenFunction {
VELOX_DEFINE_FUNCTION_TYPES(T)

// INT_MAX - 15, keep the same limit with spark.
static constexpr int32_t kMaxNumberOfElements = 2147483632;

FOLLY_ALWAYS_INLINE bool call(
out_type<Array<Generic<T1>>>& out,
const arg_type<Array<Array<Generic<T1>>>>& arrays) {
int64_t elementCount = 0;
for (const auto& array : arrays) {
if (array.has_value()) {
elementCount += array.value().size();
} else {
// Return NULL if any of the nested arrays is NULL.
return false;
}
}

VELOX_USER_CHECK_LE(
elementCount,
kMaxNumberOfElements,
"array flatten result exceeds the max array size limit {}",
kMaxNumberOfElements);

out.reserve(elementCount);
for (const auto& array : arrays) {
out.add_items(array.value());
};
return true;
}
};
} // namespace facebook::velox::functions::sparksql
6 changes: 6 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "velox/functions/prestosql/ArrayFunctions.h"
#include "velox/functions/prestosql/DateTimeFunctions.h"
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/sparksql/ArrayFlattenFunction.h"
#include "velox/functions/sparksql/ArrayMinMaxFunction.h"
#include "velox/functions/sparksql/ArraySizeFunction.h"
#include "velox/functions/sparksql/ArraySort.h"
Expand Down Expand Up @@ -407,6 +408,11 @@ void registerFunctions(const std::string& prefix) {
{prefix + "monotonically_increasing_id"});

registerFunction<UuidFunction, Varchar, Constant<int64_t>>({prefix + "uuid"});

registerFunction<
ArrayFlattenFunction,
Array<Generic<T1>>,
Array<Array<Generic<T1>>>>({prefix + "flatten"});
}

} // namespace sparksql
Expand Down
62 changes: 62 additions & 0 deletions velox/functions/sparksql/tests/ArrayFlattenTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {

class ArrayFlattenTest : public SparkFunctionBaseTest {
protected:
void testExpression(
const std::string& expression,
const std::vector<VectorPtr>& input,
const VectorPtr& expected) {
const auto result = evaluate(expression, makeRowVector(input));
assertEqualVectors(expected, result);
}
};

// Flatten integer arrays.
TEST_F(ArrayFlattenTest, intArrays) {
const auto arrayOfArrays = makeNestedArrayVectorFromJson<int64_t>({
"[[1, 1], [2, 2], [3, 3]]",
"[[4, 4]]",
"[[5, 5], [6, 6]]",
});

const auto expected = makeArrayVectorFromJson<int64_t>(
{"[1, 1, 2, 2, 3, 3]", "[4, 4]", "[5, 5, 6, 6]"});

testExpression("flatten(c0)", {arrayOfArrays}, expected);
}

// Flatten arrays with null.
TEST_F(ArrayFlattenTest, nullArray) {
const auto arrayOfArrays = makeNestedArrayVectorFromJson<int64_t>({
"[[1, 1], null, [3, 3]]",
"null",
"[[5, null], [null, 6], [null, null], []]",
});

const auto expected = makeArrayVectorFromJson<int64_t>(
{"null", "null", "[5, null, null, 6, null, null]"});

testExpression("flatten(c0)", {arrayOfArrays}, expected);
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
add_executable(
velox_functions_spark_test
ArithmeticTest.cpp
ArrayFlattenTest.cpp
ArrayMaxTest.cpp
ArrayMinTest.cpp
ArraySizeTest.cpp
Expand Down

0 comments on commit 32289f9

Please sign in to comment.