Skip to content

Commit

Permalink
try
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 27, 2024
1 parent 6ffd208 commit bce216f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
44 changes: 37 additions & 7 deletions velox/functions/lib/ArrayIntersectExcept.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ struct SetWithNull {

folly::F14FastSet<T> set;
bool hasNull{false};
bool hasNaN{false};
static constexpr vector_size_t kInitialSetSize{128};
};

// Generates a set based on the elements of an ArrayVector. Note that we take
// rightSet as a parameter (instead of returning a new one) to reuse the
// allocated memory.
template <typename T, typename TVector>
template <bool recordNaN, typename T, typename TVector>
void generateSet(
const ArrayVector* arrayVector,
const TVector* arrayElements,
Expand All @@ -54,11 +55,20 @@ void generateSet(
} else {
// Function can be called with either FlatVector or DecodedVector, but
// their APIs are slightly different.
T value;
if constexpr (std::is_same_v<TVector, DecodedVector>) {
rightSet.set.insert(arrayElements->template valueAt<T>(i));
value = arrayElements->template valueAt<T>(i);
} else {
rightSet.set.insert(arrayElements->valueAt(i));
value = arrayElements->valueAt(i);
}
if constexpr (
recordNaN &&
(std::is_same_v<T, float> || std::is_same_v<T, double>)) {
if (std::isnan(value)) {
rightSet.hasNaN = true;
}
}
rightSet.set.insert(value);
}
}
}
Expand All @@ -69,7 +79,7 @@ DecodedVector* decodeArrayElements(
const SelectivityVector& rows);

// See documentation at https://prestodb.io/docs/current/functions/array.html
template <bool isIntersect, typename T>
template <bool isIntersect, bool equalNaN, typename T>
class ArrayIntersectExceptFunction : public exec::VectorFunction {
public:
/// This class is used for both array_intersect and array_except functions
Expand Down Expand Up @@ -154,6 +164,7 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
auto idx = decodedLeftArray->index(row);
auto size = baseLeftArray->sizeAt(idx);
auto offset = baseLeftArray->offsetAt(idx);
bool hasNaN = false;

outputSet.reset();
rawNewOffsets[row] = indicesCursor;
Expand Down Expand Up @@ -182,10 +193,28 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
// for array_except) in the right-hand side, and wasn't added already
// (check outputSet).
bool addValue = false;
constexpr bool isFloating =
std::is_same_v<T, float> || std::is_same_v<T, double>;
if constexpr (isIntersect) {
addValue = rightSet.set.count(val) > 0;
if constexpr (equalNaN && isFloating) {
if (rightSet.hasNaN && std::isnan(val)) {
addValue = true;
} else {
addValue = rightSet.set.count(val) > 0;
}
} else {
addValue = rightSet.set.count(val) > 0;
}
} else {
addValue = rightSet.set.count(val) == 0;
if constexpr (equalNaN && isFloating) {
if (rightSet.hasNaN && std::isnan(val)) {
addValue = false;
} else {
addValue = rightSet.set.count(val) == 0;
}
} else {
addValue = rightSet.set.count(val) == 0;
}
}
if (addValue) {
auto it = outputSet.set.insert(val);
Expand Down Expand Up @@ -218,7 +247,8 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
auto rightArrayVector = rightHolder.get()->base()->as<ArrayVector>();
rows.applyToSelected([&](vector_size_t row) {
auto idx = rightHolder.get()->index(row);
generateSet<T>(rightArrayVector, decodedRightElements, idx, rightSet);
generateSet<equalNaN, T>(
rightArrayVector, decodedRightElements, idx, rightSet);
processRow(row, rightSet, outputSet);
});
}
Expand Down
6 changes: 4 additions & 2 deletions velox/functions/prestosql/ArrayIntersectExcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,12 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
//
// If rhs is a constant value:
if (rhs != nullptr) {
return std::make_shared<ArrayIntersectExceptFunction<isIntersect, T>>(
return std::make_shared<
ArrayIntersectExceptFunction<isIntersect, false, T>>(
validateConstantVectorAndGenerateSet<T>(rhs));
} else {
return std::make_shared<ArrayIntersectExceptFunction<isIntersect, T>>();
return std::make_shared<
ArrayIntersectExceptFunction<isIntersect, false, T>>();
}
}

Expand Down

0 comments on commit bce216f

Please sign in to comment.