// Review notes for the patch below (unified diff against
// velox/functions/lib/ArrayIntersectExcept.h and
// velox/functions/prestosql/ArrayIntersectExcept.cpp).
//
// What the visible hunks do:
//  * SetWithNull gains a `bool hasNaN{false}` member next to `hasNull`.
//  * generateSet() now reads each non-null element into a local `value`; when
//    the compile-time flag `recordNaN` is set and the floating-point type
//    check passes, it sets `rightSet.hasNaN = true` if `std::isnan(value)`.
//    The value is then inserted into `rightSet.set` as before.
//  * ArrayIntersectExceptFunction's membership test gains a branch guarded by
//    `equalNaN && isFloating`: if `rightSet.hasNaN && std::isnan(val)`, the
//    intersect path sets `addValue = true` and the except path sets
//    `addValue = false`; every other case keeps the `rightSet.set.count(val)`
//    test.
//  * The generateSet() call in the header and the two std::make_shared calls
//    in the .cpp are re-wrapped across lines.
//
// NOTE(review): this copy appears to have lost its line breaks and every
// template argument list (e.g. `template void generateSet`,
// `std::is_same_v)`, `folly::F14FastSet set`, `std::make_shared>(`). It
// cannot be applied as-is, and the exact template parameters/arguments must
// be confirmed against the original patch.
// NOTE(review): the local `bool hasNaN = false;` added in hunk
// `@@ -154,6 +164,7` is not referenced in any visible hunk - possibly unused;
// confirm against the full file.
// NOTE(review): `outputSet.reset()` appears as a context line, but the body
// of SetWithNull::reset() is not part of this patch. Verify that it also
// clears the new `hasNaN` flag; otherwise the flag could presumably carry
// over between rows when the set is reused.
// NOTE(review): `std::isnan` is newly used and no #include is added by this
// patch - confirm <cmath> is already reachable from the header.
// NOTE(review): whether several NaN values on the left-hand side are
// de-duplicated by `outputSet.set.insert(val)` depends on the set's
// hash/equality, which is not visible here - confirm.
diff --git a/velox/functions/lib/ArrayIntersectExcept.h b/velox/functions/lib/ArrayIntersectExcept.h index 7ec5fc3bea40f..311afabbd7941 100644 --- a/velox/functions/lib/ArrayIntersectExcept.h +++ b/velox/functions/lib/ArrayIntersectExcept.h @@ -32,13 +32,14 @@ struct SetWithNull { folly::F14FastSet set; bool hasNull{false}; + bool hasNaN{false}; static constexpr vector_size_t kInitialSetSize{128}; }; // Generates a set based on the elements of an ArrayVector. Note that we take // rightSet as a parameter (instead of returning a new one) to reuse the // allocated memory. -template +template void generateSet( const ArrayVector* arrayVector, const TVector* arrayElements, @@ -54,11 +55,20 @@ void generateSet( } else { // Function can be called with either FlatVector or DecodedVector, but // their APIs are slightly different. + T value; if constexpr (std::is_same_v) { - rightSet.set.insert(arrayElements->template valueAt(i)); + value = arrayElements->template valueAt(i); } else { - rightSet.set.insert(arrayElements->valueAt(i)); + value = arrayElements->valueAt(i); } + if constexpr ( + recordNaN && + (std::is_same_v || std::is_same_v)) { + if (std::isnan(value)) { + rightSet.hasNaN = true; + } + } + rightSet.set.insert(value); } } } @@ -69,7 +79,7 @@ DecodedVector* decodeArrayElements( const SelectivityVector& rows); // See documentation at https://prestodb.io/docs/current/functions/array.html -template +template class ArrayIntersectExceptFunction : public exec::VectorFunction { public: /// This class is used for both array_intersect and array_except functions @@ -154,6 +164,7 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction { auto idx = decodedLeftArray->index(row); auto size = baseLeftArray->sizeAt(idx); auto offset = baseLeftArray->offsetAt(idx); + bool hasNaN = false; outputSet.reset(); rawNewOffsets[row] = indicesCursor; @@ -182,10 +193,28 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction { // for array_except) in the right-hand 
side, and wasn't added already // (check outputSet). bool addValue = false; + constexpr bool isFloating = + std::is_same_v || std::is_same_v; if constexpr (isIntersect) { - addValue = rightSet.set.count(val) > 0; + if constexpr (equalNaN && isFloating) { + if (rightSet.hasNaN && std::isnan(val)) { + addValue = true; + } else { + addValue = rightSet.set.count(val) > 0; + } + } else { + addValue = rightSet.set.count(val) > 0; + } } else { - addValue = rightSet.set.count(val) == 0; + if constexpr (equalNaN && isFloating) { + if (rightSet.hasNaN && std::isnan(val)) { + addValue = false; + } else { + addValue = rightSet.set.count(val) == 0; + } + } else { + addValue = rightSet.set.count(val) == 0; + } } if (addValue) { auto it = outputSet.set.insert(val); @@ -218,7 +247,8 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction { auto rightArrayVector = rightHolder.get()->base()->as(); rows.applyToSelected([&](vector_size_t row) { auto idx = rightHolder.get()->index(row); - generateSet(rightArrayVector, decodedRightElements, idx, rightSet); + generateSet( + rightArrayVector, decodedRightElements, idx, rightSet); processRow(row, rightSet, outputSet); }); } diff --git a/velox/functions/prestosql/ArrayIntersectExcept.cpp b/velox/functions/prestosql/ArrayIntersectExcept.cpp index 9b020f3ac1cdb..d2dce2cbf8626 100644 --- a/velox/functions/prestosql/ArrayIntersectExcept.cpp +++ b/velox/functions/prestosql/ArrayIntersectExcept.cpp @@ -173,10 +173,12 @@ std::shared_ptr createTypedArraysIntersectExcept( // // If rhs is a constant value: if (rhs != nullptr) { - return std::make_shared>( + return std::make_shared< + ArrayIntersectExceptFunction>( validateConstantVectorAndGenerateSet(rhs)); } else { - return std::make_shared>(); + return std::make_shared< + ArrayIntersectExceptFunction>(); } }