diff --git a/velox/functions/prestosql/tests/ArraySortTest.cpp b/velox/functions/prestosql/tests/ArraySortTest.cpp index 574b3d5913024..95d7a9bb254cc 100644 --- a/velox/functions/prestosql/tests/ArraySortTest.cpp +++ b/velox/functions/prestosql/tests/ArraySortTest.cpp @@ -14,7 +14,10 @@ * limitations under the License. */ #include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/Macros.h" +#include "velox/functions/Registerer.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include #include @@ -722,6 +725,149 @@ TEST_F(ArraySortTest, failOnRowNullCompare) { } } +TEST_F(ArraySortTest, timestampWithTimezone) { + auto testArraySort = + [this]( + const std::vector>& inputArray, + const std::vector> expectedAscArray, + const std::vector> expectedDescArray) { + const auto input = makeRowVector({makeArrayVector( + {0}, + makeNullableFlatVector( + inputArray, TIMESTAMP_WITH_TIME_ZONE()))}); + const auto expectedAsc = makeArrayVector( + {0}, + makeNullableFlatVector( + expectedAscArray, TIMESTAMP_WITH_TIME_ZONE())); + const auto expectedDesc = makeArrayVector( + {0}, + makeNullableFlatVector( + expectedDescArray, TIMESTAMP_WITH_TIME_ZONE())); + + auto resultAsc = evaluate("array_sort(c0)", input); + assertEqualVectors(expectedAsc, resultAsc); + + auto resultDesc = evaluate("array_sort_desc(c0)", input); + assertEqualVectors(expectedDesc, resultDesc); + }; + + testArraySort( + {pack(2, 0), pack(1, 1), pack(0, 2)}, + {pack(0, 2), pack(1, 1), pack(2, 0)}, + {pack(2, 0), pack(1, 1), pack(0, 2)}); + testArraySort( + {pack(0, 0), pack(1, 1), pack(2, 2)}, + {pack(0, 0), pack(1, 1), pack(2, 2)}, + {pack(2, 2), pack(1, 1), pack(0, 0)}); + testArraySort( + {pack(0, 0), pack(0, 1), pack(0, 2)}, + {pack(0, 0), pack(0, 1), pack(0, 2)}, + {pack(0, 0), pack(0, 1), pack(0, 2)}); + testArraySort( + {pack(1, 0), pack(0, 1), pack(2, 2)}, + {pack(0, 1), pack(1, 0), pack(2, 2)}, + {pack(2, 2), pack(1, 0), pack(0, 1)}); + testArraySort( + {std::nullopt, pack(1, 0), pack(0, 1), pack(2, 2)}, + {pack(0, 1), pack(1, 0), pack(2, 2), std::nullopt}, + {pack(2, 2), pack(1, 0), pack(0, 1), std::nullopt}); + testArraySort( + {std::nullopt, std::nullopt, pack(1, 2), pack(0, 1), pack(2, 0)}, + {pack(0, 1), pack(1, 2), pack(2, 0), std::nullopt, std::nullopt}, + {pack(2, 0), pack(1, 2), pack(0, 1), std::nullopt, std::nullopt}); + testArraySort( + {std::nullopt, pack(1, 1), pack(0, 2), std::nullopt, pack(2, 0)}, + {pack(0, 2), pack(1, 1), pack(2, 0), std::nullopt, std::nullopt}, + {pack(2, 0), pack(1, 1), pack(0, 2), std::nullopt, std::nullopt}); + testArraySort( + {pack(1, 1), std::nullopt, pack(0, 0), pack(2, 2), std::nullopt}, + {pack(0, 0), pack(1, 1), pack(2, 2), std::nullopt, std::nullopt}, + {pack(2, 2), pack(1, 1), pack(0, 0), std::nullopt, std::nullopt}); + testArraySort( + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}); +} + +template +struct TimeZoneFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call( + int64_t& result, + const arg_type& ts) { + result = unpackZoneKeyId(*ts); + } +}; + +TEST_F(ArraySortTest, timestampWithTimezoneWithLambda) { + registerFunction( + {"timezone"}); + + auto testArraySort = + [this]( + const std::vector>& inputArray, + const std::vector> expectedAscArray, + const std::vector> expectedDescArray) { + const auto input = makeRowVector({makeArrayVector( + {0}, + makeNullableFlatVector( + inputArray, TIMESTAMP_WITH_TIME_ZONE()))}); + const auto expectedAsc = makeArrayVector( + {0}, + makeNullableFlatVector( + expectedAscArray, TIMESTAMP_WITH_TIME_ZONE())); + const auto expectedDesc = makeArrayVector( + {0}, + makeNullableFlatVector( + expectedDescArray, TIMESTAMP_WITH_TIME_ZONE())); + + auto resultAsc = evaluate("array_sort(c0, x -> timezone(x))", input); + assertEqualVectors(expectedAsc, resultAsc); + + auto resultDesc = + evaluate("array_sort_desc(c0, x -> timezone(x))", input); + assertEqualVectors(expectedDesc, resultDesc); + }; + + testArraySort( + {pack(2, 0), pack(1, 1), pack(0, 2)}, + {pack(2, 0), pack(1, 1), pack(0, 2)}, + {pack(0, 2), pack(1, 1), pack(2, 0)}); + testArraySort( + {pack(0, 0), pack(1, 1), pack(2, 2)}, + {pack(0, 0), pack(1, 1), pack(2, 2)}, + {pack(2, 2), pack(1, 1), pack(0, 0)}); + testArraySort( + {pack(0, 0), pack(0, 1), pack(0, 2)}, + {pack(0, 0), pack(0, 1), pack(0, 2)}, + {pack(0, 2), pack(0, 1), pack(0, 0)}); + testArraySort( + {pack(1, 0), pack(0, 1), pack(2, 2)}, + {pack(1, 0), pack(0, 1), pack(2, 2)}, + {pack(2, 2), pack(0, 1), pack(1, 0)}); + testArraySort( + {std::nullopt, pack(1, 0), pack(0, 1), pack(2, 2)}, + {pack(1, 0), pack(0, 1), pack(2, 2), std::nullopt}, + {pack(2, 2), pack(0, 1), pack(1, 0), std::nullopt}); + testArraySort( + {std::nullopt, std::nullopt, pack(1, 2), pack(0, 1), pack(2, 0)}, + {pack(2, 0), pack(0, 1), pack(1, 2), std::nullopt, std::nullopt}, + {pack(1, 2), pack(0, 1), pack(2, 0), std::nullopt, std::nullopt}); + testArraySort( + {std::nullopt, pack(1, 1), pack(0, 2), std::nullopt, pack(2, 0)}, + {pack(2, 0), pack(1, 1), pack(0, 2), std::nullopt, std::nullopt}, + {pack(0, 2), pack(1, 1), pack(2, 0), std::nullopt, std::nullopt}); + testArraySort( + {pack(1, 1), std::nullopt, pack(0, 0), pack(2, 2), std::nullopt}, + {pack(0, 0), pack(1, 1), pack(2, 2), std::nullopt, std::nullopt}, + {pack(2, 2), pack(1, 1), pack(0, 0), std::nullopt, std::nullopt}); + testArraySort( + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}); +} + TEST_F(ArraySortTest, floatingPointExtremes) { testFloatingPoint(); testFloatingPoint();