diff --git a/velox/functions/prestosql/ArrayDistinct.cpp b/velox/functions/prestosql/ArrayDistinct.cpp index 0ac831285f45..15a84087d689 100644 --- a/velox/functions/prestosql/ArrayDistinct.cpp +++ b/velox/functions/prestosql/ArrayDistinct.cpp @@ -38,31 +38,35 @@ struct ValueSet { template <> struct ValueSet { - using TKey = std::tuple; + struct Key { + const uint64_t hash; + const BaseVector* vector; + const vector_size_t index; + }; struct Hash { - size_t operator()(const TKey& key) const { - return std::get<0>(key); + size_t operator()(const Key& key) const { + return key.hash; } }; struct EqualTo { - bool operator()(const TKey& left, const TKey& right) const { - return std::get<1>(left) + bool operator()(const Key& left, const Key& right) const { + return left.vector ->equalValueAt( - std::get<1>(right), - std::get<2>(left), - std::get<2>(right), + right.vector, + left.index, + right.index, CompareFlags::NullHandlingMode::kNullAsValue) .value(); } }; - folly::F14FastSet values; + folly::F14FastSet values; bool insert(const BaseVector* vector, vector_size_t index) { const uint64_t hash = vector->hashValueAt(index); - return values.insert(std::make_tuple(hash, vector, index)).second; + return values.insert({hash, vector, index}).second; } void reset() { @@ -86,7 +90,7 @@ struct ValueSet { /// which will be present in the output, and wrapped into a DictionaryVector. /// Next the `lengths` and `offsets` vectors that control where output arrays /// start and end are wrapped into the output ArrayVector.template -template +template class ArrayDistinctFunction : public exec::VectorFunction { public: void apply( @@ -117,6 +121,12 @@ class ArrayDistinctFunction : public exec::VectorFunction { } private: + // We want to use ValueSet when we need custom comparison because + // it uses the Vector's implementation of compare and hash which it gets from + // the type. + using ValueSetT = std:: + conditional_t, ValueSet>; + VectorPtr applyFlat( const SelectivityVector& rows, const VectorPtr& arg, @@ -143,7 +153,7 @@ class ArrayDistinctFunction : public exec::VectorFunction { auto* rawNewOffsets = newOffsets->asMutable(); // Process the rows: store unique values in the hash table. - ValueSet uniqueSet; + ValueSetT uniqueSet; rows.applyToSelected([&](vector_size_t row) { auto size = arrayVector->sizeAt(row); @@ -159,7 +169,7 @@ class ArrayDistinctFunction : public exec::VectorFunction { } } else { bool unique; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v>) { unique = uniqueSet.insert(elements->base(), elements->index(i)); } else { auto value = elements->valueAt(i); @@ -193,7 +203,7 @@ class ArrayDistinctFunction : public exec::VectorFunction { }; template <> -VectorPtr ArrayDistinctFunction::applyFlat( +VectorPtr ArrayDistinctFunction::applyFlat( const SelectivityVector& rows, const VectorPtr& arg, exec::EvalCtx& context) const { @@ -259,11 +269,17 @@ void validateType(const std::vector& inputArgs) { // Create function template based on type. template std::shared_ptr createTyped( - const std::vector& inputArgs) { + const std::vector& inputArgs, + const TypePtr& elementType) { VELOX_CHECK_EQ(inputArgs.size(), 1); using T = typename TypeTraits::NativeType; - return std::make_shared>(); + + if (elementType->providesCustomComparison()) { + return std::make_shared>(); + } else { + return std::make_shared>(); + } } // Create function. @@ -282,7 +298,7 @@ std::shared_ptr create( } return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( - createTyped, elementType->kind(), inputArgs); + createTyped, elementType->kind(), inputArgs, elementType); } // Define function signature. diff --git a/velox/functions/prestosql/tests/ArrayDistinctTest.cpp b/velox/functions/prestosql/tests/ArrayDistinctTest.cpp index 796085daac99..25c8acb8eebd 100644 --- a/velox/functions/prestosql/tests/ArrayDistinctTest.cpp +++ b/velox/functions/prestosql/tests/ArrayDistinctTest.cpp @@ -16,6 +16,7 @@ #include #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" using namespace facebook::velox; using namespace facebook::velox::test; @@ -369,3 +370,69 @@ TEST_F(ArrayDistinctTest, unknownType) { result = evaluate("array_distinct(c0)", makeRowVector({nullArrayVector})); assertEqualVectors(expected, result); } + +TEST_F(ArrayDistinctTest, timestampWithTimezone) { + const auto testArrayDistinct = + [this]( + const std::vector>& inputArray, + const std::vector>& expectedArray) { + const auto input = makeRowVector({makeArrayVector( + {0}, + makeNullableFlatVector(inputArray, TIMESTAMP_WITH_TIME_ZONE()))}); + const auto expected = makeArrayVector( + {0}, + makeNullableFlatVector(expectedArray, TIMESTAMP_WITH_TIME_ZONE())); + + assertEqualVectors(expected, evaluate("array_distinct(c0)", input)); + }; + + testArrayDistinct({}, {}); + testArrayDistinct({pack(0, 0)}, {pack(0, 0)}); + testArrayDistinct({pack(1, 0)}, {pack(1, 0)}); + testArrayDistinct( + {pack(std::numeric_limits::min(), 0)}, + {pack(std::numeric_limits::min(), 0)}); + testArrayDistinct( + {pack(std::numeric_limits::max(), 0)}, + {pack(std::numeric_limits::max(), 0)}); + testArrayDistinct({std::nullopt}, {std::nullopt}); + testArrayDistinct({pack(-1, 0)}, {pack(-1, 0)}); + testArrayDistinct( + {pack(1, 3), pack(2, 2), pack(3, 1)}, + {pack(1, 3), pack(2, 2), pack(3, 1)}); + testArrayDistinct( + {pack(1, 0), pack(2, 1), pack(1, 2)}, {pack(1, 0), pack(2, 1)}); + testArrayDistinct({pack(1, 0), pack(1, 1), pack(1, 2)}, {pack(1, 0)}); + testArrayDistinct( + {pack(-1, 0), pack(-2, 1), pack(-3, 2)}, + {pack(-1, 0), pack(-2, 1), pack(-3, 2)}); + testArrayDistinct( + {pack(-1, 0), pack(-2, 1), pack(-1, 2)}, {pack(-1, 0), pack(-2, 1)}); + testArrayDistinct({pack(-1, 0), pack(-1, 1), pack(-1, 2)}, {pack(-1, 0)}); + testArrayDistinct({std::nullopt, std::nullopt, std::nullopt}, {std::nullopt}); + testArrayDistinct( + {pack(1, 0), pack(2, 1), pack(-2, 2), pack(1, 3)}, + {pack(1, 0), pack(2, 1), pack(-2, 2)}); + testArrayDistinct( + {pack(1, 0), + pack(1, 1), + pack(-2, 2), + pack(-2, 3), + pack(-2, 4), + pack(4, 5), + pack(8, 6)}, + {pack(1, 0), pack(-2, 2), pack(4, 5), pack(8, 6)}); + testArrayDistinct( + {pack(3, 0), pack(8, 1), std::nullopt}, + {pack(3, 0), pack(8, 1), std::nullopt}); + testArrayDistinct( + {pack(1, 0), + pack(2, 1), + pack(3, 2), + std::nullopt, + pack(4, 3), + pack(1, 4), + pack(2, 5), + std::nullopt}, + {pack(1, 0), pack(2, 1), pack(3, 2), std::nullopt, pack(4, 3)}); +}