From 6ffd2086dc260a4c299c5707c41f8aa8be0ef1c3 Mon Sep 17 00:00:00 2001 From: rui-mo Date: Tue, 27 Feb 2024 16:31:31 +0800 Subject: [PATCH] refactor --- velox/functions/lib/ArrayIntersectExcept.cpp | 35 +++ velox/functions/lib/ArrayIntersectExcept.h | 244 ++++++++++++++++++ velox/functions/lib/CMakeLists.txt | 1 + .../prestosql/ArrayIntersectExcept.cpp | 237 +---------------- 4 files changed, 281 insertions(+), 236 deletions(-) create mode 100644 velox/functions/lib/ArrayIntersectExcept.cpp create mode 100644 velox/functions/lib/ArrayIntersectExcept.h diff --git a/velox/functions/lib/ArrayIntersectExcept.cpp b/velox/functions/lib/ArrayIntersectExcept.cpp new file mode 100644 index 0000000000000..934da33d03593 --- /dev/null +++ b/velox/functions/lib/ArrayIntersectExcept.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/ArrayIntersectExcept.h" + +namespace facebook::velox::functions { + +DecodedVector* decodeArrayElements( + exec::LocalDecodedVector& arrayDecoder, + exec::LocalDecodedVector& elementsDecoder, + const SelectivityVector& rows) { + auto decodedVector = arrayDecoder.get(); + auto baseArrayVector = arrayDecoder->base()->as(); + + // Decode and acquire array elements vector. + auto elementsVector = baseArrayVector->elements(); + auto elementsSelectivityRows = toElementRows( + elementsVector->size(), rows, baseArrayVector, decodedVector->indices()); + elementsDecoder.get()->decode(*elementsVector, elementsSelectivityRows); + auto decodedElementsVector = elementsDecoder.get(); + return decodedElementsVector; +} +} // namespace facebook::velox::functions diff --git a/velox/functions/lib/ArrayIntersectExcept.h b/velox/functions/lib/ArrayIntersectExcept.h new file mode 100644 index 0000000000000..7ec5fc3bea40f --- /dev/null +++ b/velox/functions/lib/ArrayIntersectExcept.h @@ -0,0 +1,244 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/LambdaFunctionUtil.h" +#include "velox/functions/lib/RowsTranslationUtil.h" + +namespace facebook::velox::functions { + +template +struct SetWithNull { + SetWithNull(vector_size_t initialSetSize = kInitialSetSize) { + set.reserve(initialSetSize); + } + + void reset() { + set.clear(); + hasNull = false; + } + + folly::F14FastSet set; + bool hasNull{false}; + static constexpr vector_size_t kInitialSetSize{128}; +}; + +// Generates a set based on the elements of an ArrayVector. Note that we take +// rightSet as a parameter (instead of returning a new one) to reuse the +// allocated memory. +template +void generateSet( + const ArrayVector* arrayVector, + const TVector* arrayElements, + vector_size_t idx, + SetWithNull& rightSet) { + auto size = arrayVector->sizeAt(idx); + auto offset = arrayVector->offsetAt(idx); + rightSet.reset(); + + for (vector_size_t i = offset; i < (offset + size); ++i) { + if (arrayElements->isNullAt(i)) { + rightSet.hasNull = true; + } else { + // Function can be called with either FlatVector or DecodedVector, but + // their APIs are slightly different. + if constexpr (std::is_same_v) { + rightSet.set.insert(arrayElements->template valueAt(i)); + } else { + rightSet.set.insert(arrayElements->valueAt(i)); + } + } + } +} + +DecodedVector* decodeArrayElements( + exec::LocalDecodedVector& arrayDecoder, + exec::LocalDecodedVector& elementsDecoder, + const SelectivityVector& rows); + +// See documentation at https://prestodb.io/docs/current/functions/array.html +template +class ArrayIntersectExceptFunction : public exec::VectorFunction { + public: + /// This class is used for both array_intersect and array_except functions + /// (behavior controlled at compile time by the isIntersect template + /// variable). Both these functions take two ArrayVectors as inputs (left and + /// right) and leverage two sets to calculate the intersection (or except): + /// + /// - rightSet: a set that contains all (distinct) elements from the + /// right-hand side array. + /// - outputSet: a set that contains the elements that were already added to + /// the output (to prevent duplicates). + /// + /// Along with each set, we maintain a `hasNull` flag that indicates whether + /// null is present in the arrays, to prevent the use of optional types or + /// special values. + /// + /// Zero element copy: + /// + /// In order to prevent copies of array elements, the function reuses the + /// internal elements() vector from the left-hand side ArrayVector. + /// + /// First a new vector is created containing the indices of the elements + /// which will be present in the output, and wrapped into a DictionaryVector. + /// Next the `lengths` and `offsets` vectors that control where output arrays + /// start and end are wrapped into the output ArrayVector. + /// + /// Constant optimization: + /// + /// If the rhs values passed to either array_intersect() or array_except() + /// are constant (array literals) we create a set before instantiating the + /// object and pass as a constructor parameter (constantSet). + + ArrayIntersectExceptFunction() = default; + + explicit ArrayIntersectExceptFunction(SetWithNull constantSet) + : constantSet_(std::move(constantSet)) {} + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + memory::MemoryPool* pool = context.pool(); + BaseVector* left = args[0].get(); + BaseVector* right = args[1].get(); + + exec::LocalDecodedVector leftHolder(context, *left, rows); + auto decodedLeftArray = leftHolder.get(); + auto baseLeftArray = decodedLeftArray->base()->as(); + + // Decode and acquire array elements vector. + exec::LocalDecodedVector leftElementsDecoder(context); + auto decodedLeftElements = + decodeArrayElements(leftHolder, leftElementsDecoder, rows); + + auto leftElementsCount = + countElements(rows, *decodedLeftArray); + vector_size_t rowCount = left->size(); + + // Allocate new vectors for indices, nulls, length and offsets. + BufferPtr newIndices = allocateIndices(leftElementsCount, pool); + BufferPtr newElementNulls = + AlignedBuffer::allocate(leftElementsCount, pool, bits::kNotNull); + BufferPtr newLengths = allocateSizes(rowCount, pool); + BufferPtr newOffsets = allocateOffsets(rowCount, pool); + + // Pointers and cursors to the raw data. + auto rawNewIndices = newIndices->asMutable(); + auto rawNewElementNulls = newElementNulls->asMutable(); + auto rawNewLengths = newLengths->asMutable(); + auto rawNewOffsets = newOffsets->asMutable(); + + vector_size_t indicesCursor = 0; + + // Lambda that process each row. This is detached from the code so we can + // apply it differently based on whether the right-hand side set is constant + // or not. + auto processRow = [&](vector_size_t row, + const SetWithNull& rightSet, + SetWithNull& outputSet) { + auto idx = decodedLeftArray->index(row); + auto size = baseLeftArray->sizeAt(idx); + auto offset = baseLeftArray->offsetAt(idx); + + outputSet.reset(); + rawNewOffsets[row] = indicesCursor; + + // Scans the array elements on the left-hand side. + for (vector_size_t i = offset; i < (offset + size); ++i) { + if (decodedLeftElements->isNullAt(i)) { + // For a NULL value not added to the output row yet, insert in + // array_intersect if it was found on the rhs (and not found in the + // case of array_except). + if (!outputSet.hasNull) { + bool setNull = false; + if constexpr (isIntersect) { + setNull = rightSet.hasNull; + } else { + setNull = !rightSet.hasNull; + } + if (setNull) { + bits::setNull(rawNewElementNulls, indicesCursor++, true); + outputSet.hasNull = true; + } + } + } else { + auto val = decodedLeftElements->valueAt(i); + // For array_intersect, add the element if it is found (not found + // for array_except) in the right-hand side, and wasn't added already + // (check outputSet). + bool addValue = false; + if constexpr (isIntersect) { + addValue = rightSet.set.count(val) > 0; + } else { + addValue = rightSet.set.count(val) == 0; + } + if (addValue) { + auto it = outputSet.set.insert(val); + if (it.second) { + rawNewIndices[indicesCursor++] = i; + } + } + } + } + rawNewLengths[row] = indicesCursor - rawNewOffsets[row]; + }; + + SetWithNull outputSet; + + // Optimized case when the right-hand side array is constant. + if (constantSet_.has_value()) { + rows.applyToSelected([&](vector_size_t row) { + processRow(row, *constantSet_, outputSet); + }); + } + // General case when no arrays are constant and both sets need to be + // computed for each row. + else { + exec::LocalDecodedVector rightHolder(context, *right, rows); + // Decode and acquire array elements vector. + exec::LocalDecodedVector rightElementsHolder(context); + auto decodedRightElements = + decodeArrayElements(rightHolder, rightElementsHolder, rows); + SetWithNull rightSet; + auto rightArrayVector = rightHolder.get()->base()->as(); + rows.applyToSelected([&](vector_size_t row) { + auto idx = rightHolder.get()->index(row); + generateSet(rightArrayVector, decodedRightElements, idx, rightSet); + processRow(row, rightSet, outputSet); + }); + } + + auto newElements = BaseVector::wrapInDictionary( + newElementNulls, newIndices, indicesCursor, baseLeftArray->elements()); + auto resultArray = std::make_shared( + pool, + outputType, + nullptr, + rowCount, + newOffsets, + newLengths, + newElements); + context.moveOrCopyResult(resultArray, rows, result); + } + + // If one of the arrays is constant, this member will store a pointer to the + // set generated from its elements, which is calculated only once, before + // instantiating this object. + std::optional> constantSet_; +}; // class ArrayIntersectExcept +} // namespace facebook::velox::functions diff --git a/velox/functions/lib/CMakeLists.txt b/velox/functions/lib/CMakeLists.txt index fe1bdb405ff9a..fdaa33bda7156 100644 --- a/velox/functions/lib/CMakeLists.txt +++ b/velox/functions/lib/CMakeLists.txt @@ -22,6 +22,7 @@ target_link_libraries(velox_functions_util velox_vector velox_common_base) add_library( velox_functions_lib + ArrayIntersectExcept.cpp CheckDuplicateKeys.cpp DateTimeFormatter.cpp DateTimeFormatterBuilder.cpp diff --git a/velox/functions/prestosql/ArrayIntersectExcept.cpp b/velox/functions/prestosql/ArrayIntersectExcept.cpp index 1336ff2a0fe3b..9b020f3ac1cdb 100644 --- a/velox/functions/prestosql/ArrayIntersectExcept.cpp +++ b/velox/functions/prestosql/ArrayIntersectExcept.cpp @@ -13,245 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/expression/VectorFunction.h" -#include "velox/functions/lib/LambdaFunctionUtil.h" -#include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/functions/lib/ArrayIntersectExcept.h" namespace facebook::velox::functions { namespace { -template - -struct SetWithNull { - SetWithNull(vector_size_t initialSetSize = kInitialSetSize) { - set.reserve(initialSetSize); - } - - void reset() { - set.clear(); - hasNull = false; - } - - folly::F14FastSet set; - bool hasNull{false}; - static constexpr vector_size_t kInitialSetSize{128}; -}; -// Generates a set based on the elements of an ArrayVector. Note that we take -// rightSet as a parameter (instead of returning a new one) to reuse the -// allocated memory. -template -void generateSet( - const ArrayVector* arrayVector, - const TVector* arrayElements, - vector_size_t idx, - SetWithNull& rightSet) { - auto size = arrayVector->sizeAt(idx); - auto offset = arrayVector->offsetAt(idx); - rightSet.reset(); - - for (vector_size_t i = offset; i < (offset + size); ++i) { - if (arrayElements->isNullAt(i)) { - rightSet.hasNull = true; - } else { - // Function can be called with either FlatVector or DecodedVector, but - // their APIs are slightly different. - if constexpr (std::is_same_v) { - rightSet.set.insert(arrayElements->template valueAt(i)); - } else { - rightSet.set.insert(arrayElements->valueAt(i)); - } - } - } -} - -DecodedVector* decodeArrayElements( - exec::LocalDecodedVector& arrayDecoder, - exec::LocalDecodedVector& elementsDecoder, - const SelectivityVector& rows) { - auto decodedVector = arrayDecoder.get(); - auto baseArrayVector = arrayDecoder->base()->as(); - - // Decode and acquire array elements vector. - auto elementsVector = baseArrayVector->elements(); - auto elementsSelectivityRows = toElementRows( - elementsVector->size(), rows, baseArrayVector, decodedVector->indices()); - elementsDecoder.get()->decode(*elementsVector, elementsSelectivityRows); - auto decodedElementsVector = elementsDecoder.get(); - return decodedElementsVector; -} - -// See documentation at https://prestodb.io/docs/current/functions/array.html -template -class ArrayIntersectExceptFunction : public exec::VectorFunction { - public: - /// This class is used for both array_intersect and array_except functions - /// (behavior controlled at compile time by the isIntersect template - /// variable). Both these functions take two ArrayVectors as inputs (left and - /// right) and leverage two sets to calculate the intersection (or except): - /// - /// - rightSet: a set that contains all (distinct) elements from the - /// right-hand side array. - /// - outputSet: a set that contains the elements that were already added to - /// the output (to prevent duplicates). - /// - /// Along with each set, we maintain a `hasNull` flag that indicates whether - /// null is present in the arrays, to prevent the use of optional types or - /// special values. - /// - /// Zero element copy: - /// - /// In order to prevent copies of array elements, the function reuses the - /// internal elements() vector from the left-hand side ArrayVector. - /// - /// First a new vector is created containing the indices of the elements - /// which will be present in the output, and wrapped into a DictionaryVector. - /// Next the `lengths` and `offsets` vectors that control where output arrays - /// start and end are wrapped into the output ArrayVector. - /// - /// Constant optimization: - /// - /// If the rhs values passed to either array_intersect() or array_except() - /// are constant (array literals) we create a set before instantiating the - /// object and pass as a constructor parameter (constantSet). - - ArrayIntersectExceptFunction() = default; - - explicit ArrayIntersectExceptFunction(SetWithNull constantSet) - : constantSet_(std::move(constantSet)) {} - - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& outputType, - exec::EvalCtx& context, - VectorPtr& result) const override { - memory::MemoryPool* pool = context.pool(); - BaseVector* left = args[0].get(); - BaseVector* right = args[1].get(); - - exec::LocalDecodedVector leftHolder(context, *left, rows); - auto decodedLeftArray = leftHolder.get(); - auto baseLeftArray = decodedLeftArray->base()->as(); - - // Decode and acquire array elements vector. - exec::LocalDecodedVector leftElementsDecoder(context); - auto decodedLeftElements = - decodeArrayElements(leftHolder, leftElementsDecoder, rows); - - auto leftElementsCount = - countElements(rows, *decodedLeftArray); - vector_size_t rowCount = left->size(); - - // Allocate new vectors for indices, nulls, length and offsets. - BufferPtr newIndices = allocateIndices(leftElementsCount, pool); - BufferPtr newElementNulls = - AlignedBuffer::allocate(leftElementsCount, pool, bits::kNotNull); - BufferPtr newLengths = allocateSizes(rowCount, pool); - BufferPtr newOffsets = allocateOffsets(rowCount, pool); - - // Pointers and cursors to the raw data. - auto rawNewIndices = newIndices->asMutable(); - auto rawNewElementNulls = newElementNulls->asMutable(); - auto rawNewLengths = newLengths->asMutable(); - auto rawNewOffsets = newOffsets->asMutable(); - - vector_size_t indicesCursor = 0; - - // Lambda that process each row. This is detached from the code so we can - // apply it differently based on whether the right-hand side set is constant - // or not. - auto processRow = [&](vector_size_t row, - const SetWithNull& rightSet, - SetWithNull& outputSet) { - auto idx = decodedLeftArray->index(row); - auto size = baseLeftArray->sizeAt(idx); - auto offset = baseLeftArray->offsetAt(idx); - - outputSet.reset(); - rawNewOffsets[row] = indicesCursor; - - // Scans the array elements on the left-hand side. - for (vector_size_t i = offset; i < (offset + size); ++i) { - if (decodedLeftElements->isNullAt(i)) { - // For a NULL value not added to the output row yet, insert in - // array_intersect if it was found on the rhs (and not found in the - // case of array_except). - if (!outputSet.hasNull) { - bool setNull = false; - if constexpr (isIntersect) { - setNull = rightSet.hasNull; - } else { - setNull = !rightSet.hasNull; - } - if (setNull) { - bits::setNull(rawNewElementNulls, indicesCursor++, true); - outputSet.hasNull = true; - } - } - } else { - auto val = decodedLeftElements->valueAt(i); - // For array_intersect, add the element if it is found (not found - // for array_except) in the right-hand side, and wasn't added already - // (check outputSet). - bool addValue = false; - if constexpr (isIntersect) { - addValue = rightSet.set.count(val) > 0; - } else { - addValue = rightSet.set.count(val) == 0; - } - if (addValue) { - auto it = outputSet.set.insert(val); - if (it.second) { - rawNewIndices[indicesCursor++] = i; - } - } - } - } - rawNewLengths[row] = indicesCursor - rawNewOffsets[row]; - }; - - SetWithNull outputSet; - - // Optimized case when the right-hand side array is constant. - if (constantSet_.has_value()) { - rows.applyToSelected([&](vector_size_t row) { - processRow(row, *constantSet_, outputSet); - }); - } - // General case when no arrays are constant and both sets need to be - // computed for each row. - else { - exec::LocalDecodedVector rightHolder(context, *right, rows); - // Decode and acquire array elements vector. - exec::LocalDecodedVector rightElementsHolder(context); - auto decodedRightElements = - decodeArrayElements(rightHolder, rightElementsHolder, rows); - SetWithNull rightSet; - auto rightArrayVector = rightHolder.get()->base()->as(); - rows.applyToSelected([&](vector_size_t row) { - auto idx = rightHolder.get()->index(row); - generateSet(rightArrayVector, decodedRightElements, idx, rightSet); - processRow(row, rightSet, outputSet); - }); - } - - auto newElements = BaseVector::wrapInDictionary( - newElementNulls, newIndices, indicesCursor, baseLeftArray->elements()); - auto resultArray = std::make_shared( - pool, - outputType, - nullptr, - rowCount, - newOffsets, - newLengths, - newElements); - context.moveOrCopyResult(resultArray, rows, result); - } - - // If one of the arrays is constant, this member will store a pointer to the - // set generated from its elements, which is calculated only once, before - // instantiating this object. - std::optional> constantSet_; -}; // class ArrayIntersectExcept template class ArraysOverlapFunction : public exec::VectorFunction {