Skip to content

Commit

Permalink
custom comparison clean up simple UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinwilfong authored and facebook-github-bot committed Oct 1, 2024
1 parent c917f49 commit 66c23ef
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 137 deletions.
55 changes: 0 additions & 55 deletions velox/functions/prestosql/Comparisons.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

#include "velox/common/base/CompareFlags.h"
#include "velox/functions/Macros.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {
Expand All @@ -37,18 +36,6 @@ namespace facebook::velox::functions {
} \
};

#define VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(Name, tsExpr, TResult) \
template <typename T> \
struct Name##TimestampWithTimezone { \
VELOX_DEFINE_FUNCTION_TYPES(T); \
FOLLY_ALWAYS_INLINE void call( \
bool& result, \
const arg_type<TimestampWithTimezone>& lhs, \
const arg_type<TimestampWithTimezone>& rhs) { \
result = (tsExpr); \
} \
};

VELOX_GEN_BINARY_EXPR(
LtFunction,
lhs < rhs,
Expand All @@ -66,11 +53,6 @@ VELOX_GEN_BINARY_EXPR(
lhs >= rhs,
util::floating_point::NaNAwareGreaterThanEqual<TInput>{}(lhs, rhs));

VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(LtFunction, lhs < rhs, bool);
VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(GtFunction, lhs > rhs, bool);
VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(LteFunction, lhs <= rhs, bool);
VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(GteFunction, lhs >= rhs, bool);

#undef VELOX_GEN_BINARY_EXPR
#undef VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE

Expand Down Expand Up @@ -128,18 +110,6 @@ struct EqFunction {
}
};

template <typename T>
struct EqFunctionTimestampWithTimezone {
VELOX_DEFINE_FUNCTION_TYPES(T);

void call(
bool& result,
const arg_type<TimestampWithTimezone>& lhs,
const arg_type<TimestampWithTimezone>& rhs) {
result = lhs == rhs;
}
};

template <typename T>
struct NeqFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
Expand Down Expand Up @@ -168,18 +138,6 @@ struct NeqFunction {
}
};

template <typename T>
struct NeqFunctionTimestampWithTimezone {
VELOX_DEFINE_FUNCTION_TYPES(T);

void call(
bool& result,
const arg_type<TimestampWithTimezone>& lhs,
const arg_type<TimestampWithTimezone>& rhs) {
result = lhs != rhs;
}
};

template <typename TExec>
struct BetweenFunction {
template <typename T>
Expand All @@ -195,17 +153,4 @@ struct BetweenFunction {
}
};

template <typename TExec>
struct BetweenFunctionTimestampWithTimezone {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

void call(
bool& result,
const arg_type<TimestampWithTimezone>& value,
const arg_type<TimestampWithTimezone>& low,
const arg_type<TimestampWithTimezone>& high) {
result = value >= low && value <= high;
}
};

} // namespace facebook::velox::functions
52 changes: 13 additions & 39 deletions velox/functions/prestosql/GreatestLeast.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
#pragma once

#include <cmath>
#include "velox/expression/ComplexViewTypes.h"
#include "velox/functions/Macros.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

namespace facebook::velox::functions {
namespace details {
Expand All @@ -41,10 +41,10 @@ struct ExtremeValueFunction {
out_type<T>& result,
const arg_type<T>& firstElement,
const arg_type<Variadic<T>>& remainingElement) {
auto currentValue = firstElement;
auto currentValue = extractValue(firstElement);

for (auto element : remainingElement) {
auto candidateValue = element.value();
auto candidateValue = extractValue(element.value());

if constexpr (isLeast) {
if (smallerThan(candidateValue, currentValue)) {
Expand All @@ -61,6 +61,16 @@ struct ExtremeValueFunction {
}

private:
template <typename U>
auto extractValue(const U& wrapper) const {
return wrapper;
}

int64_t extractValue(
const exec::CustomTypeWithCustomComparisonView<int64_t>& wrapper) const {
return *wrapper;
}

template <typename K>
bool greaterThan(const K& lhs, const K& rhs) const {
if constexpr (std::is_same_v<K, double> || std::is_same_v<K, float>) {
Expand Down Expand Up @@ -91,48 +101,12 @@ struct ExtremeValueFunction {
return lhs < rhs;
}
};

template <typename TExec, bool isLeast>
struct ExtremeValueFunctionTimestampWithTimezone {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

FOLLY_ALWAYS_INLINE void call(
out_type<TimestampWithTimezone>& result,
const arg_type<TimestampWithTimezone>& firstElement,
const arg_type<Variadic<TimestampWithTimezone>>& remainingElement) {
auto currentValue = *firstElement;

for (auto element : remainingElement) {
auto candidateValue = *element.value();

if constexpr (isLeast) {
if (unpackMillisUtc(candidateValue) < unpackMillisUtc(currentValue)) {
currentValue = candidateValue;
}
} else {
if (unpackMillisUtc(candidateValue) > unpackMillisUtc(currentValue)) {
currentValue = candidateValue;
}
}
}

result = currentValue;
}
};
} // namespace details

template <typename TExec, typename T>
using LeastFunction = details::ExtremeValueFunction<TExec, T, true>;

template <typename TExec>
using LeastFunctionTimestampWithTimezone =
details::ExtremeValueFunctionTimestampWithTimezone<TExec, true>;

template <typename TExec, typename T>
using GreatestFunction = details::ExtremeValueFunction<TExec, T, false>;

template <typename TExec>
using GreatestFunctionTimestampWithTimezone =
details::ExtremeValueFunctionTimestampWithTimezone<TExec, false>;

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/Comparisons.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/Type.h"

namespace facebook::velox::functions {
Expand All @@ -27,6 +28,8 @@ void registerNonSimdizableScalar(const std::vector<std::string>& aliases) {
registerFunction<T, TReturn, Varbinary, Varbinary>(aliases);
registerFunction<T, TReturn, bool, bool>(aliases);
registerFunction<T, TReturn, Timestamp, Timestamp>(aliases);
registerFunction<T, TReturn, TimestampWithTimezone, TimestampWithTimezone>(
aliases);
}
} // namespace

Expand All @@ -38,52 +41,22 @@ void registerComparisonFunctions(const std::string& prefix) {
registerNonSimdizableScalar<EqFunction, bool>({prefix + "eq"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_eq, prefix + "eq");
registerFunction<EqFunction, bool, Generic<T1>, Generic<T1>>({prefix + "eq"});
registerFunction<
EqFunctionTimestampWithTimezone,
bool,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "eq"});

registerNonSimdizableScalar<NeqFunction, bool>({prefix + "neq"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_neq, prefix + "neq");
registerFunction<NeqFunction, bool, Generic<T1>, Generic<T1>>(
{prefix + "neq"});
registerFunction<
NeqFunctionTimestampWithTimezone,
bool,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "neq"});

registerNonSimdizableScalar<LtFunction, bool>({prefix + "lt"});
registerFunction<
LtFunctionTimestampWithTimezone,
bool,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "lt"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_lt, prefix + "lt");

registerNonSimdizableScalar<GtFunction, bool>({prefix + "gt"});
registerFunction<
GtFunctionTimestampWithTimezone,
bool,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "gt"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_gt, prefix + "gt");

registerNonSimdizableScalar<LteFunction, bool>({prefix + "lte"});
registerFunction<
LteFunctionTimestampWithTimezone,
bool,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "lte"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_lte, prefix + "lte");

registerNonSimdizableScalar<GteFunction, bool>({prefix + "gte"});
registerFunction<
GteFunctionTimestampWithTimezone,
bool,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "gte"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_gte, prefix + "gte");

registerFunction<DistinctFromFunction, bool, Generic<T1>, Generic<T1>>(
Expand Down Expand Up @@ -132,7 +105,7 @@ void registerComparisonFunctions(const std::string& prefix) {
IntervalYearMonth,
IntervalYearMonth>({prefix + "between"});
registerFunction<
BetweenFunctionTimestampWithTimezone,
BetweenFunction,
bool,
TimestampWithTimezone,
TimestampWithTimezone,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/functions/prestosql/GreatestLeast.h"
#include "velox/functions/prestosql/InPredicate.h"
#include "velox/functions/prestosql/Reduce.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

namespace facebook::velox::functions {

Expand Down Expand Up @@ -55,18 +56,7 @@ void registerAllGreatestLeastFunctions(const std::string& prefix) {
registerGreatestLeastFunction<ShortDecimal<P1, S1>>(prefix);
registerGreatestLeastFunction<Date>(prefix);
registerGreatestLeastFunction<Timestamp>(prefix);

registerFunction<
GreatestFunctionTimestampWithTimezone,
TimestampWithTimezone,
TimestampWithTimezone,
Variadic<TimestampWithTimezone>>({prefix + "greatest"});

registerFunction<
LeastFunctionTimestampWithTimezone,
TimestampWithTimezone,
TimestampWithTimezone,
Variadic<TimestampWithTimezone>>({prefix + "least"});
registerGreatestLeastFunction<TimestampWithTimezone>(prefix);
}
} // namespace

Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/ComparisonsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/functions/Udf.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/tests/utils/CustomTypesForTesting.h"
#include "velox/type/tz/TimeZoneMap.h"

Expand Down

0 comments on commit 66c23ef

Please sign in to comment.