Skip to content

Commit

Permalink
feat: Add mathematical operators for IntervalYearMonth type
Browse files Browse the repository at this point in the history
  • Loading branch information
pramodsatya authored and Pramod Satya committed Nov 21, 2024
1 parent ebfb1e5 commit d1bac96
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 23 deletions.
60 changes: 51 additions & 9 deletions velox/functions/prestosql/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ inline constexpr int kMinRadix = 2;
inline constexpr int kMaxRadix = 36;
inline constexpr long kLongMax = std::numeric_limits<int64_t>::max();
inline constexpr long kLongMin = std::numeric_limits<int64_t>::min();
inline constexpr long kIntegerMax = std::numeric_limits<int32_t>::max();
inline constexpr long kIntegerMin = std::numeric_limits<int32_t>::min();

inline constexpr char digits[36] = {
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b',
Expand Down Expand Up @@ -73,7 +75,8 @@ struct MultiplyFunction {
}
};

// Multiply function for IntervalDayTime * Double and Double * IntervalDayTime.
// Multiply function for IntervalDayTime * Double, Double * IntervalDayTime,
// IntervalYearMonth * Double and Double * IntervalYearMonth.
template <typename T>
struct IntervalMultiplyFunction {
FOLLY_ALWAYS_INLINE double sanitizeInput(double d) {
Expand All @@ -84,25 +87,38 @@ struct IntervalMultiplyFunction {
}

template <
typename TResult,
typename T1,
typename T2,
typename = std::enable_if_t<
(std::is_same_v<T1, int64_t> && std::is_same_v<T2, double>) ||
(std::is_same_v<T1, double> && std::is_same_v<T2, int64_t>)>>
FOLLY_ALWAYS_INLINE void call(int64_t& result, T1 a, T2 b) {
(std::is_same_v<T1, double> && std::is_same_v<T2, int64_t>) ||
(std::is_same_v<T1, int32_t> && std::is_same_v<T2, double>) ||
(std::is_same_v<T1, double> && std::is_same_v<T2, int32_t>)>>
FOLLY_ALWAYS_INLINE void call(TResult& result, T1 a, T2 b) {
double resultDouble;
if constexpr (std::is_same_v<T1, double>) {
resultDouble = sanitizeInput(a) * b;
} else {
resultDouble = sanitizeInput(b) * a;
}

if LIKELY (
std::isfinite(resultDouble) && resultDouble >= kLongMin &&
resultDouble <= kMaxDoubleBelowInt64Max) {
result = int64_t(resultDouble);
if constexpr (std::is_same_v<TResult, int64_t>) {
if LIKELY (
std::isfinite(resultDouble) && resultDouble >= kLongMin &&
resultDouble <= kMaxDoubleBelowInt64Max) {
result = int64_t(resultDouble);
} else {
result = resultDouble > 0 ? kLongMax : kLongMin;
}
} else {
result = resultDouble > 0 ? kLongMax : kLongMin;
if LIKELY (
std::isfinite(resultDouble) && resultDouble >= kIntegerMin &&
resultDouble <= kIntegerMax) {
result = int32_t(resultDouble);
} else {
result = resultDouble > 0 ? kIntegerMax : kIntegerMin;
}
}
}
};
Expand All @@ -124,7 +140,7 @@ struct DivideFunction {
};

template <typename T>
struct IntervalDivideFunction {
struct IntervalDayTimeDivideFunction {
FOLLY_ALWAYS_INLINE void call(int64_t& result, int64_t a, double b)
// Depend on compiler have correct behaviour for divide by zero
#if defined(__has_feature)
Expand All @@ -149,6 +165,32 @@ struct IntervalDivideFunction {
}
};

template <typename T>
struct IntervalYearMonthDivideFunction {
FOLLY_ALWAYS_INLINE void call(int32_t& result, int32_t a, double b)
// Depend on compiler have correct behaviour for divide by zero
#if defined(__has_feature)
#if __has_feature(__address_sanitizer__)
__attribute__((__no_sanitize__("float-divide-by-zero")))
__attribute__((__no_sanitize__("float-cast-overflow")))
#endif
#endif
{
if UNLIKELY (a == 0 || std::isnan(b) || !std::isfinite(b)) {
result = 0;
return;
}
double resultDouble = a / b;
if LIKELY (
std::isfinite(resultDouble) && resultDouble >= kIntegerMin &&
resultDouble <= kIntegerMax) {
result = int32_t(resultDouble);
} else {
result = resultDouble > 0 ? kIntegerMax : kIntegerMin;
}
}
};

template <typename T>
struct ModulusFunction {
template <typename TInput>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,22 @@ void registerMathOperators(const std::string& prefix = "") {
IntervalDayTime,
IntervalDayTime,
IntervalDayTime>({prefix + "plus"});
registerFunction<
PlusFunction,
IntervalYearMonth,
IntervalYearMonth,
IntervalYearMonth>({prefix + "plus"});
registerBinaryFloatingPoint<MinusFunction>({prefix + "minus"});
registerFunction<
MinusFunction,
IntervalDayTime,
IntervalDayTime,
IntervalDayTime>({prefix + "minus"});
registerFunction<
MinusFunction,
IntervalYearMonth,
IntervalYearMonth,
IntervalYearMonth>({prefix + "minus"});
registerBinaryFloatingPoint<MultiplyFunction>({prefix + "multiply"});
registerFunction<MultiplyFunction, IntervalDayTime, IntervalDayTime, int64_t>(
{prefix + "multiply"});
Expand All @@ -49,12 +59,37 @@ void registerMathOperators(const std::string& prefix = "") {
IntervalDayTime,
double,
IntervalDayTime>({prefix + "multiply"});
registerFunction<
MultiplyFunction,
IntervalYearMonth,
IntervalYearMonth,
int32_t>({prefix + "multiply"});
registerFunction<
MultiplyFunction,
IntervalYearMonth,
int32_t,
IntervalYearMonth>({prefix + "multiply"});
registerFunction<
IntervalMultiplyFunction,
IntervalYearMonth,
IntervalYearMonth,
double>({prefix + "multiply"});
registerFunction<
IntervalMultiplyFunction,
IntervalYearMonth,
double,
IntervalYearMonth>({prefix + "multiply"});
registerBinaryFloatingPoint<DivideFunction>({prefix + "divide"});
registerFunction<
IntervalDivideFunction,
IntervalDayTimeDivideFunction,
IntervalDayTime,
IntervalDayTime,
double>({prefix + "divide"});
registerFunction<
IntervalYearMonthDivideFunction,
IntervalYearMonth,
IntervalYearMonth,
double>({prefix + "divide"});
registerBinaryFloatingPoint<ModulusFunction>({prefix + "mod"});
}

Expand Down
101 changes: 88 additions & 13 deletions velox/functions/prestosql/tests/ArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ constexpr float kInfF = std::numeric_limits<float>::infinity();
constexpr float kNanF = std::numeric_limits<float>::quiet_NaN();
constexpr int64_t kLongMax = std::numeric_limits<int64_t>::max();
constexpr int64_t kLongMin = std::numeric_limits<int64_t>::min();
constexpr int32_t kIntegerMax = std::numeric_limits<int32_t>::max();
constexpr int32_t kIntegerMin = std::numeric_limits<int32_t>::min();

MATCHER(IsNan, "is NaN") {
return arg && std::isnan(*arg);
Expand Down Expand Up @@ -93,27 +95,47 @@ class ArithmeticTest : public functions::test::FunctionBaseTest {
};

TEST_F(ArithmeticTest, plus) {
// Test plus for intervals.
auto op1 = makeNullableFlatVector<int64_t>(
// Test plus for interval day time.
auto dayTimeOp1 = makeNullableFlatVector<int64_t>(
{-1, 2, -3, 4, kLongMax, -1, std::nullopt, 0}, INTERVAL_DAY_TIME());
auto op2 = makeNullableFlatVector<int64_t>(
auto dayTimeOp2 = makeNullableFlatVector<int64_t>(
{2, -3, -1, 1, 1, kLongMin, 0, std::nullopt}, INTERVAL_DAY_TIME());
auto expected = makeNullableFlatVector<int64_t>(
auto dayTimeExpected = makeNullableFlatVector<int64_t>(
{1, -1, -4, 5, kLongMin, kLongMax, std::nullopt, std::nullopt},
INTERVAL_DAY_TIME());
assertExpression("c0 + c1", op1, op2, expected);
assertExpression("c0 + c1", dayTimeOp1, dayTimeOp2, dayTimeExpected);

// Test plus for interval year month.
auto yearMonthOp1 = makeNullableFlatVector<int32_t>(
{-1, 2, -3, 4, kIntegerMax, -1, std::nullopt, 0}, INTERVAL_YEAR_MONTH());
auto yearMonthOp2 = makeNullableFlatVector<int32_t>(
{2, -3, -1, 1, 1, kIntegerMin, 0, std::nullopt}, INTERVAL_YEAR_MONTH());
auto yearMonthExpected = makeNullableFlatVector<int32_t>(
{1, -1, -4, 5, kIntegerMin, kIntegerMax, std::nullopt, std::nullopt},
INTERVAL_YEAR_MONTH());
assertExpression("c0 + c1", yearMonthOp1, yearMonthOp2, yearMonthExpected);
}

TEST_F(ArithmeticTest, minus) {
// Test plus for intervals.
auto op1 = makeNullableFlatVector<int64_t>(
// Test minus for interval day time.
auto dayTimeOp1 = makeNullableFlatVector<int64_t>(
{-1, 2, -3, 4, kLongMin, -1, std::nullopt, 0}, INTERVAL_DAY_TIME());
auto op2 = makeNullableFlatVector<int64_t>(
auto dayTimeOp2 = makeNullableFlatVector<int64_t>(
{2, 3, -4, 1, 1, kLongMax, 0, std::nullopt}, INTERVAL_DAY_TIME());
auto expected = makeNullableFlatVector<int64_t>(
auto dayTimeExpected = makeNullableFlatVector<int64_t>(
{-3, -1, 1, 3, kLongMax, kLongMin, std::nullopt, std::nullopt},
INTERVAL_DAY_TIME());
assertExpression("c0 - c1", op1, op2, expected);
assertExpression("c0 - c1", dayTimeOp1, dayTimeOp2, dayTimeExpected);

// Test minus for interval year month.
auto yearMonthOp1 = makeNullableFlatVector<int32_t>(
{-1, 2, -3, 4, kIntegerMin, -1, std::nullopt, 0}, INTERVAL_YEAR_MONTH());
auto yearMonthOp2 = makeNullableFlatVector<int32_t>(
{2, 3, -4, 1, 1, kIntegerMax, 0, std::nullopt}, INTERVAL_YEAR_MONTH());
auto yearMonthExpected = makeNullableFlatVector<int32_t>(
{-3, -1, 1, 3, kIntegerMax, kIntegerMin, std::nullopt, std::nullopt},
INTERVAL_YEAR_MONTH());
assertExpression("c0 - c1", yearMonthOp1, yearMonthOp2, yearMonthExpected);
}

TEST_F(ArithmeticTest, divide)
Expand Down Expand Up @@ -144,7 +166,7 @@ __attribute__((__no_sanitize__("float-divide-by-zero")))
assertExpression<double>(
"c0 / c1", {10.5, 9.2, 0.0, 0.0}, {2, 0, 0, -1}, {5.25, kInf, kNan, 0.0});

// Test interval divided by double.
// Test interval day time and interval year month types divided by double.
auto intervalVector = makeNullableFlatVector<int64_t>(
{3, 6, 9, std::nullopt, 12, 15, 18, 21, 0, 0, 1}, INTERVAL_DAY_TIME());
auto doubleVector = makeNullableFlatVector<double>(
Expand All @@ -164,6 +186,13 @@ __attribute__((__no_sanitize__("float-divide-by-zero")))
INTERVAL_DAY_TIME());
assertExpression("c0 / c1", intervalVector, doubleVector, expected);

auto intervalYmVector = makeNullableFlatVector<int32_t>(
{3, 6, 9, std::nullopt, 12, 15, 18, 21, 0, 0, 1}, INTERVAL_YEAR_MONTH());
auto expectedYm = makeNullableFlatVector<int32_t>(
{6, -3, 1, std::nullopt, std::nullopt, 0, 0, 0, 0, 0, 100000000},
INTERVAL_YEAR_MONTH());
assertExpression("c0 / c1", intervalYmVector, doubleVector, expectedYm);

intervalVector = makeFlatVector<int64_t>(
{1, 1, 1, 1, kLongMax, kLongMin, kLongMax, kLongMin},
INTERVAL_DAY_TIME());
Expand All @@ -180,6 +209,21 @@ __attribute__((__no_sanitize__("float-divide-by-zero")))
kLongMax},
INTERVAL_DAY_TIME());
assertExpression("c0 / c1", intervalVector, doubleVector, expected);

intervalYmVector = makeFlatVector<int32_t>(
{1, 1, 1, 1, kIntegerMax, kIntegerMin, kIntegerMax, kIntegerMin},
INTERVAL_YEAR_MONTH());
expectedYm = makeFlatVector<int32_t>(
{kIntegerMax,
kIntegerMin,
kIntegerMax,
kIntegerMin,
kIntegerMax,
kIntegerMin,
kIntegerMin,
kIntegerMax},
INTERVAL_YEAR_MONTH());
assertExpression("c0 / c1", intervalYmVector, doubleVector, expectedYm);
}

TEST_F(ArithmeticTest, multiply) {
Expand All @@ -189,7 +233,7 @@ TEST_F(ArithmeticTest, multiply) {
{-1},
"integer overflow: -2147483648 * -1");

// Test multiplication of interval type with bigint.
// Test multiplication of interval day time type with bigint.
auto intervalVector = makeNullableFlatVector<int64_t>(
{1, 2, 3, std::nullopt, 10, 20}, INTERVAL_DAY_TIME());
auto bigintVector = makeNullableFlatVector<int64_t>(
Expand All @@ -199,7 +243,17 @@ TEST_F(ArithmeticTest, multiply) {
assertExpression("c0 * c1", intervalVector, bigintVector, expected);
assertExpression("c1 * c0", intervalVector, bigintVector, expected);

// Test multiplication of interval type with double.
// Test multiplication of interval year month type with integer.
auto intervalYmVector = makeNullableFlatVector<int32_t>(
{1, 2, 3, std::nullopt, 10, 20}, INTERVAL_YEAR_MONTH());
auto integerVector = makeNullableFlatVector<int32_t>(
{1, std::nullopt, 3, 4, kIntegerMax, kIntegerMin});
auto expectedYm = makeNullableFlatVector<int32_t>(
{1, std::nullopt, 9, std::nullopt, -10, 0}, INTERVAL_YEAR_MONTH());
assertExpression("c0 * c1", intervalYmVector, integerVector, expectedYm);
assertExpression("c1 * c0", intervalYmVector, integerVector, expectedYm);

// Test multiplication of interval day time type with double.
intervalVector = makeNullableFlatVector<int64_t>(
{1, 2, 3, std::nullopt, 10, 20, 30, 40, 1000, 1000, 1000, 1000},
INTERVAL_DAY_TIME());
Expand Down Expand Up @@ -232,6 +286,27 @@ TEST_F(ArithmeticTest, multiply) {
INTERVAL_DAY_TIME());
assertExpression("c0 * c1", intervalVector, doubleVector, expected);
assertExpression("c1 * c0", intervalVector, doubleVector, expected);

// Test multiplication of interval year month type with double.
intervalYmVector = makeNullableFlatVector<int32_t>(
{1, 2, 3, std::nullopt, 10, 20, 30, 40, 1000, 1000, 1000, 1000},
INTERVAL_YEAR_MONTH());
expectedYm = makeNullableFlatVector<int32_t>(
{-1,
4,
0,
std::nullopt,
0,
0,
kIntegerMax,
kIntegerMin,
kIntegerMax,
kIntegerMin,
kIntegerMax,
kIntegerMin},
INTERVAL_YEAR_MONTH());
assertExpression("c0 * c1", intervalYmVector, doubleVector, expectedYm);
assertExpression("c1 * c0", intervalYmVector, doubleVector, expectedYm);
}

TEST_F(ArithmeticTest, mod) {
Expand Down

0 comments on commit d1bac96

Please sign in to comment.