Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 27, 2024
1 parent 5fa959f commit 6ffd208
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 236 deletions.
35 changes: 35 additions & 0 deletions velox/functions/lib/ArrayIntersectExcept.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/lib/ArrayIntersectExcept.h"

namespace facebook::velox::functions {

/// Decodes the elements vector that backs the array vector held by
/// `arrayDecoder`, restricted to the element rows reachable from `rows`.
///
/// @param arrayDecoder Holder of the decoded array vector; its base vector is
/// accessed as an ArrayVector.
/// @param elementsDecoder Holder whose DecodedVector receives the decoded
/// elements.
/// @param rows Selected top-level rows; translated into element rows through
/// toElementRows() before decoding.
/// @return The DecodedVector obtained from `elementsDecoder`.
DecodedVector* decodeArrayElements(
    exec::LocalDecodedVector& arrayDecoder,
    exec::LocalDecodedVector& elementsDecoder,
    const SelectivityVector& rows) {
  DecodedVector* decodedArray = arrayDecoder.get();
  auto arrayBase = decodedArray->base()->as<ArrayVector>();

  // Map the selected array rows onto the rows of the underlying elements
  // vector so that only those elements get decoded.
  const auto& elements = arrayBase->elements();
  auto elementRows = toElementRows(
      elements->size(), rows, arrayBase, decodedArray->indices());

  DecodedVector* decodedElements = elementsDecoder.get();
  decodedElements->decode(*elements, elementRows);
  return decodedElements;
}
} // namespace facebook::velox::functions
244 changes: 244 additions & 0 deletions velox/functions/lib/ArrayIntersectExcept.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"

namespace facebook::velox::functions {

/// A hash set of values of type T paired with a flag that records whether a
/// null was encountered, so that nulls can be tracked without storing optional
/// or sentinel values in the set itself.
template <typename T>
struct SetWithNull {
  /// Number of entries reserved in `set` when no explicit size is given.
  static constexpr vector_size_t kInitialSetSize{128};

  /// Reserves room for `initialSetSize` entries up front.
  SetWithNull(vector_size_t initialSetSize = kInitialSetSize) {
    set.reserve(initialSetSize);
  }

  /// Clears the null flag and removes every value from the set.
  void reset() {
    hasNull = false;
    set.clear();
  }

  /// Distinct non-null values.
  folly::F14FastSet<T> set;

  /// True once a null has been recorded.
  bool hasNull{false};
};

// Generates a set based on the elements of an ArrayVector. Note that we take
// rightSet as a parameter (instead of returning a new one) to reuse the
// allocated memory.
template <typename T, typename TVector>
void generateSet(
const ArrayVector* arrayVector,
const TVector* arrayElements,
vector_size_t idx,
SetWithNull<T>& rightSet) {
auto size = arrayVector->sizeAt(idx);
auto offset = arrayVector->offsetAt(idx);
rightSet.reset();

for (vector_size_t i = offset; i < (offset + size); ++i) {
if (arrayElements->isNullAt(i)) {
rightSet.hasNull = true;
} else {
// Function can be called with either FlatVector or DecodedVector, but
// their APIs are slightly different.
if constexpr (std::is_same_v<TVector, DecodedVector>) {
rightSet.set.insert(arrayElements->template valueAt<T>(i));
} else {
rightSet.set.insert(arrayElements->valueAt(i));
}
}
}
}

/// Decodes the elements vector of the array vector held by `arrayDecoder`.
///
/// @param arrayDecoder Holder of the decoded array vector; its base vector is
/// accessed as an ArrayVector.
/// @param elementsDecoder Holder whose DecodedVector is used to decode the
/// array's elements vector.
/// @param rows Selected top-level rows; they are translated into element rows
/// before the elements are decoded.
/// @return The DecodedVector obtained from `elementsDecoder` after decoding.
DecodedVector* decodeArrayElements(
exec::LocalDecodedVector& arrayDecoder,
exec::LocalDecodedVector& elementsDecoder,
const SelectivityVector& rows);

// See documentation at https://prestodb.io/docs/current/functions/array.html
template <bool isIntersect, typename T>
class ArrayIntersectExceptFunction : public exec::VectorFunction {
 public:
  /// Shared implementation of array_intersect and array_except. The
  /// `isIntersect` template argument selects the behavior at compile time.
  /// Both functions receive two ArrayVectors (left and right) and rely on two
  /// sets per row:
  ///
  ///  - rightSet: the distinct elements of the right-hand side array.
  ///  - outputSet: the elements already emitted for the current row, used to
  ///    avoid producing duplicates.
  ///
  /// Each set carries a `hasNull` flag telling whether null belongs to it,
  /// which avoids optional types or special marker values.
  ///
  /// Zero element copy:
  ///
  /// Array elements are never copied. The output reuses the elements() vector
  /// of the left-hand side ArrayVector: a buffer with the indices of the
  /// surviving elements is built and wrapped, together with that elements
  /// vector, into a dictionary. The `lengths` and `offsets` buffers delimiting
  /// each output array are then handed to the resulting ArrayVector.
  ///
  /// Constant optimization:
  ///
  /// When the right-hand side is constant (an array literal), its set is
  /// computed before this object is created and passed to the constructor
  /// (constantSet), so it is not rebuilt for every row.

  ArrayIntersectExceptFunction() = default;

  explicit ArrayIntersectExceptFunction(SetWithNull<T> constantSet)
      : constantSet_(std::move(constantSet)) {}

  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& outputType,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    memory::MemoryPool* pool = context.pool();
    BaseVector* left = args[0].get();
    BaseVector* right = args[1].get();

    exec::LocalDecodedVector leftArrayDecoder(context, *left, rows);
    auto decodedLeft = leftArrayDecoder.get();
    auto leftArrays = decodedLeft->base()->as<ArrayVector>();

    // Decoded view over the elements of the left-hand side arrays.
    exec::LocalDecodedVector leftElementsDecoder(context);
    DecodedVector* leftElements =
        decodeArrayElements(leftArrayDecoder, leftElementsDecoder, rows);

    const auto maxOutputElements =
        countElements<ArrayVector>(rows, *decodedLeft);
    const vector_size_t numRows = left->size();

    // Output buffers: element indices, element nulls, array sizes and array
    // offsets.
    BufferPtr outIndices = allocateIndices(maxOutputElements, pool);
    BufferPtr outElementNulls = AlignedBuffer::allocate<bool>(
        maxOutputElements, pool, bits::kNotNull);
    BufferPtr outSizes = allocateSizes(numRows, pool);
    BufferPtr outOffsets = allocateOffsets(numRows, pool);

    auto rawIndices = outIndices->asMutable<vector_size_t>();
    auto rawElementNulls = outElementNulls->asMutable<uint64_t>();
    auto rawSizes = outSizes->asMutable<vector_size_t>();
    auto rawOffsets = outOffsets->asMutable<vector_size_t>();

    // Next free slot in the output indices / nulls buffers.
    vector_size_t cursor = 0;

    // Produces the output array for a single row. Kept as a lambda so that the
    // same logic serves both the constant and the non-constant right-hand side
    // paths below.
    auto processRow = [&](vector_size_t row,
                          const SetWithNull<T>& rightSet,
                          SetWithNull<T>& outputSet) {
      const auto arrayIdx = decodedLeft->index(row);
      const auto begin = leftArrays->offsetAt(arrayIdx);
      const auto end = begin + leftArrays->sizeAt(arrayIdx);

      rawOffsets[row] = cursor;
      outputSet.reset();

      // Walk the elements of the left-hand side array.
      for (vector_size_t pos = begin; pos < end; ++pos) {
        if (leftElements->isNullAt(pos)) {
          // Emit at most one NULL per row: array_intersect emits it when the
          // right-hand side has a NULL, array_except when it does not.
          if (!outputSet.hasNull && rightSet.hasNull == isIntersect) {
            bits::setNull(rawElementNulls, cursor++, true);
            outputSet.hasNull = true;
          }
          continue;
        }

        auto value = leftElements->valueAt<T>(pos);
        // array_intersect keeps values present in the right-hand side,
        // array_except keeps the absent ones. In both cases a value is emitted
        // only the first time it is seen in this row (tracked by outputSet).
        const bool inRight = rightSet.set.count(value) > 0;
        if (inRight == isIntersect && outputSet.set.insert(value).second) {
          rawIndices[cursor++] = pos;
        }
      }
      rawSizes[row] = cursor - rawOffsets[row];
    };

    SetWithNull<T> outputSet;

    if (constantSet_.has_value()) {
      // Fast path: the right-hand side set was computed once up front.
      const SetWithNull<T>& rightSet = *constantSet_;
      rows.applyToSelected(
          [&](vector_size_t row) { processRow(row, rightSet, outputSet); });
    } else {
      // General path: the right-hand side set is rebuilt for every row.
      exec::LocalDecodedVector rightArrayDecoder(context, *right, rows);
      exec::LocalDecodedVector rightElementsDecoder(context);
      DecodedVector* rightElements =
          decodeArrayElements(rightArrayDecoder, rightElementsDecoder, rows);
      SetWithNull<T> rightSet;
      auto decodedRight = rightArrayDecoder.get();
      auto rightArrays = decodedRight->base()->as<ArrayVector>();
      rows.applyToSelected([&](vector_size_t row) {
        generateSet<T>(
            rightArrays, rightElements, decodedRight->index(row), rightSet);
        processRow(row, rightSet, outputSet);
      });
    }

    // Expose the selected left-hand side elements through a dictionary and
    // assemble the output ArrayVector around it.
    auto dictionaryElements = BaseVector::wrapInDictionary(
        outElementNulls, outIndices, cursor, leftArrays->elements());
    auto arrayResult = std::make_shared<ArrayVector>(
        pool,
        outputType,
        nullptr,
        numRows,
        outOffsets,
        outSizes,
        dictionaryElements);
    context.moveOrCopyResult(arrayResult, rows, result);
  }

  // Holds the set supplied to the constructor (computed once, before this
  // object is created, from the elements of a constant array). Empty when the
  // function was default-constructed.
  std::optional<SetWithNull<T>> constantSet_;
}; // class ArrayIntersectExceptFunction
} // namespace facebook::velox::functions
1 change: 1 addition & 0 deletions velox/functions/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ target_link_libraries(velox_functions_util velox_vector velox_common_base)

add_library(
velox_functions_lib
ArrayIntersectExcept.cpp
CheckDuplicateKeys.cpp
DateTimeFormatter.cpp
DateTimeFormatterBuilder.cpp
Expand Down
Loading

0 comments on commit 6ffd208

Please sign in to comment.