diff --git a/velox/functions/prestosql/Comparisons.h b/velox/functions/prestosql/Comparisons.h index b5059a0bf195c..a8caaaa836ae2 100644 --- a/velox/functions/prestosql/Comparisons.h +++ b/velox/functions/prestosql/Comparisons.h @@ -17,7 +17,6 @@ #include "velox/common/base/CompareFlags.h" #include "velox/functions/Macros.h" -#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { @@ -37,18 +36,6 @@ namespace facebook::velox::functions { } \ }; -#define VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(Name, tsExpr, TResult) \ - template \ - struct Name##TimestampWithTimezone { \ - VELOX_DEFINE_FUNCTION_TYPES(T); \ - FOLLY_ALWAYS_INLINE void call( \ - bool& result, \ - const arg_type& lhs, \ - const arg_type& rhs) { \ - result = (tsExpr); \ - } \ - }; - VELOX_GEN_BINARY_EXPR( LtFunction, lhs < rhs, @@ -66,11 +53,6 @@ VELOX_GEN_BINARY_EXPR( lhs >= rhs, util::floating_point::NaNAwareGreaterThanEqual{}(lhs, rhs)); -VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(LtFunction, lhs < rhs, bool); -VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(GtFunction, lhs > rhs, bool); -VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(LteFunction, lhs <= rhs, bool); -VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(GteFunction, lhs >= rhs, bool); - #undef VELOX_GEN_BINARY_EXPR #undef VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE @@ -128,18 +110,6 @@ struct EqFunction { } }; -template -struct EqFunctionTimestampWithTimezone { - VELOX_DEFINE_FUNCTION_TYPES(T); - - void call( - bool& result, - const arg_type& lhs, - const arg_type& rhs) { - result = lhs == rhs; - } -}; - template struct NeqFunction { VELOX_DEFINE_FUNCTION_TYPES(T); @@ -168,18 +138,6 @@ struct NeqFunction { } }; -template -struct NeqFunctionTimestampWithTimezone { - VELOX_DEFINE_FUNCTION_TYPES(T); - - void call( - bool& result, - const arg_type& lhs, - const arg_type& rhs) { - result = lhs != rhs; - } -}; - template struct BetweenFunction { template @@ -195,17 +153,4 @@ struct BetweenFunction { } }; -template -struct BetweenFunctionTimestampWithTimezone { - VELOX_DEFINE_FUNCTION_TYPES(TExec); - - void call( - bool& result, - const arg_type& value, - const arg_type& low, - const arg_type& high) { - result = value >= low && value <= high; - } -}; - } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/GreatestLeast.h b/velox/functions/prestosql/GreatestLeast.h index 23070805933c3..a5a0a23e00d56 100644 --- a/velox/functions/prestosql/GreatestLeast.h +++ b/velox/functions/prestosql/GreatestLeast.h @@ -16,8 +16,8 @@ #pragma once #include +#include "velox/expression/ComplexViewTypes.h" #include "velox/functions/Macros.h" -#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" namespace facebook::velox::functions { namespace details { @@ -41,10 +41,10 @@ struct ExtremeValueFunction { out_type& result, const arg_type& firstElement, const arg_type>& remainingElement) { - auto currentValue = firstElement; + auto currentValue = extractValue(firstElement); for (auto element : remainingElement) { - auto candidateValue = element.value(); + auto candidateValue = extractValue(element.value()); if constexpr (isLeast) { if (smallerThan(candidateValue, currentValue)) { @@ -61,6 +61,16 @@ struct ExtremeValueFunction { } private: + template + auto extractValue(const U& wrapper) const { + return wrapper; + } + + int64_t extractValue( + const exec::CustomTypeWithCustomComparisonView& wrapper) const { + return *wrapper; + } + template bool greaterThan(const K& lhs, const K& rhs) const { if constexpr (std::is_same_v || std::is_same_v) { @@ -91,48 +101,12 @@ struct ExtremeValueFunction { return lhs < rhs; } }; - -template -struct ExtremeValueFunctionTimestampWithTimezone { - VELOX_DEFINE_FUNCTION_TYPES(TExec); - - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& firstElement, - const arg_type>& remainingElement) { - auto currentValue = *firstElement; - - for (auto element : remainingElement) { - auto candidateValue = *element.value(); - - if constexpr (isLeast) { - if (unpackMillisUtc(candidateValue) < unpackMillisUtc(currentValue)) { - currentValue = candidateValue; - } - } else { - if (unpackMillisUtc(candidateValue) > unpackMillisUtc(currentValue)) { - currentValue = candidateValue; - } - } - } - - result = currentValue; - } -}; } // namespace details template using LeastFunction = details::ExtremeValueFunction; -template -using LeastFunctionTimestampWithTimezone = - details::ExtremeValueFunctionTimestampWithTimezone; - template using GreatestFunction = details::ExtremeValueFunction; -template -using GreatestFunctionTimestampWithTimezone = - details::ExtremeValueFunctionTimestampWithTimezone; - } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp index d44d289e38262..4f964d27270ea 100644 --- a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp @@ -16,6 +16,7 @@ #include "velox/functions/Registerer.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/Comparisons.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/type/Type.h" namespace facebook::velox::functions { @@ -27,6 +28,8 @@ void registerNonSimdizableScalar(const std::vector& aliases) { registerFunction(aliases); registerFunction(aliases); registerFunction(aliases); + registerFunction( + aliases); } } // namespace @@ -38,52 +41,22 @@ void registerComparisonFunctions(const std::string& prefix) { registerNonSimdizableScalar({prefix + "eq"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_eq, prefix + "eq"); registerFunction, Generic>({prefix + "eq"}); - registerFunction< - EqFunctionTimestampWithTimezone, - bool, - TimestampWithTimezone, - TimestampWithTimezone>({prefix + "eq"}); registerNonSimdizableScalar({prefix + "neq"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_neq, prefix + "neq"); registerFunction, Generic>( {prefix + "neq"}); - registerFunction< - NeqFunctionTimestampWithTimezone, - bool, - TimestampWithTimezone, - TimestampWithTimezone>({prefix + "neq"}); registerNonSimdizableScalar({prefix + "lt"}); - registerFunction< - LtFunctionTimestampWithTimezone, - bool, - TimestampWithTimezone, - TimestampWithTimezone>({prefix + "lt"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_lt, prefix + "lt"); registerNonSimdizableScalar({prefix + "gt"}); - registerFunction< - GtFunctionTimestampWithTimezone, - bool, - TimestampWithTimezone, - TimestampWithTimezone>({prefix + "gt"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_gt, prefix + "gt"); registerNonSimdizableScalar({prefix + "lte"}); - registerFunction< - LteFunctionTimestampWithTimezone, - bool, - TimestampWithTimezone, - TimestampWithTimezone>({prefix + "lte"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_lte, prefix + "lte"); registerNonSimdizableScalar({prefix + "gte"}); - registerFunction< - GteFunctionTimestampWithTimezone, - bool, - TimestampWithTimezone, - TimestampWithTimezone>({prefix + "gte"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_gte, prefix + "gte"); registerFunction, Generic>( @@ -132,7 +105,7 @@ void registerComparisonFunctions(const std::string& prefix) { IntervalYearMonth, IntervalYearMonth>({prefix + "between"}); registerFunction< - BetweenFunctionTimestampWithTimezone, + BetweenFunction, bool, TimestampWithTimezone, TimestampWithTimezone, diff --git a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp index bb154865e672a..dda91e413da91 100644 --- a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp @@ -21,6 +21,7 @@ #include "velox/functions/prestosql/GreatestLeast.h" #include "velox/functions/prestosql/InPredicate.h" #include "velox/functions/prestosql/Reduce.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" namespace facebook::velox::functions { @@ -55,18 +56,7 @@ void registerAllGreatestLeastFunctions(const std::string& prefix) { registerGreatestLeastFunction>(prefix); registerGreatestLeastFunction(prefix); registerGreatestLeastFunction(prefix); - - registerFunction< - GreatestFunctionTimestampWithTimezone, - TimestampWithTimezone, - TimestampWithTimezone, - Variadic>({prefix + "greatest"}); - - registerFunction< - LeastFunctionTimestampWithTimezone, - TimestampWithTimezone, - TimestampWithTimezone, - Variadic>({prefix + "least"}); + registerGreatestLeastFunction(prefix); } } // namespace diff --git a/velox/functions/prestosql/tests/ComparisonsTest.cpp b/velox/functions/prestosql/tests/ComparisonsTest.cpp index 38098b5628dd9..3e2a502afaea7 100644 --- a/velox/functions/prestosql/tests/ComparisonsTest.cpp +++ b/velox/functions/prestosql/tests/ComparisonsTest.cpp @@ -19,6 +19,7 @@ #include "velox/functions/Udf.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/type/tests/utils/CustomTypesForTesting.h" #include "velox/type/tz/TimeZoneMap.h"