Skip to content

Commit

Permalink
Fix NaN handling in comparison functions (facebookincubator#10165)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#10165

This change ensures that comparison functions treat all NaN bit
representations as equal and greater than infinity.

Summary of changes:
Currently, numerical primitive types, including floating points, have
an optimized SIMD implementation. However, this relies on low-level
instructions that follow IEEE floating point semantics, where NaN
compared to any other number, including itself, returns false. There
is no trivial way to implement NaN checking and handling using SIMD,
so for now, we are falling back to the regular non-SIMD version that
employs comparators supporting NaN handling in the intended way.

Additionally, this change currently handles this by adding a special
case in the ComparisonSimdFunction vector function instead of
registering the SimpleFunction for floating types. This is because
registering an additional/different function to handle floating types
is causing a regression in an internal integration which is still under
investigation.

Reviewed By: kgpai

Differential Revision: D58471982
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed Jul 3, 2024
1 parent 4a0ce3e commit 9ab6c70
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 53 deletions.
67 changes: 50 additions & 17 deletions velox/functions/prestosql/Comparisons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@

namespace facebook::velox::functions {

using Eq = std::equal_to<>;
using Neq = std::not_equal_to<>;
using Lt = std::less<>;
using Lte = std::less_equal<>;
using Gt = std::greater<>;
using Gte = std::greater_equal<>;

namespace {

/// This class implements comparison for vectors of primitive types using SIMD.
Expand Down Expand Up @@ -94,6 +101,29 @@ struct SimdComparator {
}
}

template <typename T>
inline bool compare(T& l, T& r) const {
if constexpr (std::is_floating_point_v<T>) {
bool filtered = false;
if constexpr (std::is_same_v<ComparisonOp, Eq>) {
filtered = util::floating_point::NaNAwareEquals<T>{}(l, r);
} else if constexpr (std::is_same_v<ComparisonOp, Neq>) {
filtered = !util::floating_point::NaNAwareEquals<T>{}(l, r);
} else if constexpr (std::is_same_v<ComparisonOp, Lt>) {
filtered = util::floating_point::NaNAwareLessThan<T>{}(l, r);
} else if constexpr (std::is_same_v<ComparisonOp, Lte>) {
filtered = util::floating_point::NaNAwareLessThanEqual<T>{}(l, r);
} else if constexpr (std::is_same_v<ComparisonOp, Gt>) {
filtered = util::floating_point::NaNAwareGreaterThan<T>{}(l, r);
} else if constexpr (std::is_same_v<ComparisonOp, Gte>) {
filtered = util::floating_point::NaNAwareGreaterThanEqual<T>{}(l, r);
}
return filtered;
} else {
return ComparisonOp()(l, r);
}
}

template <
TypeKind kind,
typename std::enable_if_t<
Expand All @@ -113,24 +143,27 @@ struct SimdComparator {
auto resultVector = result->asUnchecked<FlatVector<bool>>();
auto rawResult = resultVector->mutableRawValues<uint8_t>();

auto isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) &&
bool isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) &&
(rhs.isConstantEncoding() || rhs.isFlatEncoding()) &&
rows.isAllSelected();

if (!isSimdizable || std::is_same_v<T, int128_t>) {
static const bool isTypeNotSupported =
std::is_same_v<T, int128_t> || std::is_floating_point_v<T>;

if (!isSimdizable || isTypeNotSupported) {
exec::LocalDecodedVector lhsDecoded(context, lhs, rows);
exec::LocalDecodedVector rhsDecoded(context, rhs, rows);

context.template applyToSelectedNoThrow(rows, [&](auto row) {
auto l = lhsDecoded->template valueAt<T>(row);
auto r = rhsDecoded->template valueAt<T>(row);
auto filtered = ComparisonOp()(l, r);
auto filtered = compare(l, r);
resultVector->set(row, filtered);
});
return;
}

if constexpr (!std::is_same_v<T, int128_t>) {
if constexpr (!isTypeNotSupported) {
if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) {
auto l = lhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
auto r = rhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
Expand Down Expand Up @@ -242,7 +275,7 @@ class ComparisonSimdFunction : public exec::VectorFunction {
}

exec::FunctionCanonicalName getCanonicalName() const override {
return std::is_same_v<ComparisonOp, std::less<>>
return std::is_same_v<ComparisonOp, Lt>
? exec::FunctionCanonicalName::kLt
: exec::FunctionCanonicalName::kUnknown;
}
Expand All @@ -252,32 +285,32 @@ class ComparisonSimdFunction : public exec::VectorFunction {

VELOX_DECLARE_VECTOR_FUNCTION(
udf_simd_comparison_eq,
(ComparisonSimdFunction<std::equal_to<>>::signatures()),
(std::make_unique<ComparisonSimdFunction<std::equal_to<>>>()));
(ComparisonSimdFunction<Eq>::signatures()),
(std::make_unique<ComparisonSimdFunction<Eq>>()));

VELOX_DECLARE_VECTOR_FUNCTION(
udf_simd_comparison_neq,
(ComparisonSimdFunction<std::not_equal_to<>>::signatures()),
(std::make_unique<ComparisonSimdFunction<std::not_equal_to<>>>()));
(ComparisonSimdFunction<Neq>::signatures()),
(std::make_unique<ComparisonSimdFunction<Neq>>()));

VELOX_DECLARE_VECTOR_FUNCTION(
udf_simd_comparison_lt,
(ComparisonSimdFunction<std::less<>>::signatures()),
(std::make_unique<ComparisonSimdFunction<std::less<>>>()));
(ComparisonSimdFunction<Lt>::signatures()),
(std::make_unique<ComparisonSimdFunction<Lt>>()));

VELOX_DECLARE_VECTOR_FUNCTION(
udf_simd_comparison_gt,
(ComparisonSimdFunction<std::greater<>>::signatures()),
(std::make_unique<ComparisonSimdFunction<std::greater<>>>()));
(ComparisonSimdFunction<Gt>::signatures()),
(std::make_unique<ComparisonSimdFunction<Gt>>()));

VELOX_DECLARE_VECTOR_FUNCTION(
udf_simd_comparison_lte,
(ComparisonSimdFunction<std::less_equal<>>::signatures()),
(std::make_unique<ComparisonSimdFunction<std::less_equal<>>>()));
(ComparisonSimdFunction<Lte>::signatures()),
(std::make_unique<ComparisonSimdFunction<Lte>>()));

VELOX_DECLARE_VECTOR_FUNCTION(
udf_simd_comparison_gte,
(ComparisonSimdFunction<std::greater_equal<>>::signatures()),
(std::make_unique<ComparisonSimdFunction<std::greater_equal<>>>()));
(ComparisonSimdFunction<Gte>::signatures()),
(std::make_unique<ComparisonSimdFunction<Gte>>()));

} // namespace facebook::velox::functions
57 changes: 44 additions & 13 deletions velox/functions/prestosql/Comparisons.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@
#include "velox/common/base/CompareFlags.h"
#include "velox/functions/Macros.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {

#define VELOX_GEN_BINARY_EXPR(Name, Expr, TResult) \
template <typename T> \
struct Name { \
VELOX_DEFINE_FUNCTION_TYPES(T); \
template <typename TInput> \
FOLLY_ALWAYS_INLINE void \
call(TResult& result, const TInput& lhs, const TInput& rhs) { \
result = (Expr); \
} \
#define VELOX_GEN_BINARY_EXPR(Name, Expr, ExprForFloats) \
template <typename T> \
struct Name { \
VELOX_DEFINE_FUNCTION_TYPES(T); \
template <typename TInput> \
FOLLY_ALWAYS_INLINE void \
call(bool& result, const TInput& lhs, const TInput& rhs) { \
if constexpr (std::is_floating_point_v<TInput>) { \
result = (ExprForFloats); \
return; \
} \
result = (Expr); \
} \
};

#define VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(Name, tsExpr, TResult) \
Expand All @@ -44,10 +49,22 @@ namespace facebook::velox::functions {
} \
};

VELOX_GEN_BINARY_EXPR(LtFunction, lhs < rhs, bool);
VELOX_GEN_BINARY_EXPR(GtFunction, lhs > rhs, bool);
VELOX_GEN_BINARY_EXPR(LteFunction, lhs <= rhs, bool);
VELOX_GEN_BINARY_EXPR(GteFunction, lhs >= rhs, bool);
VELOX_GEN_BINARY_EXPR(
LtFunction,
lhs < rhs,
util::floating_point::NaNAwareLessThan<TInput>{}(lhs, rhs));
VELOX_GEN_BINARY_EXPR(
GtFunction,
lhs > rhs,
util::floating_point::NaNAwareGreaterThan<TInput>{}(lhs, rhs));
VELOX_GEN_BINARY_EXPR(
LteFunction,
lhs <= rhs,
util::floating_point::NaNAwareLessThanEqual<TInput>{}(lhs, rhs));
VELOX_GEN_BINARY_EXPR(
GteFunction,
lhs >= rhs,
util::floating_point::NaNAwareGreaterThanEqual<TInput>{}(lhs, rhs));

VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(
LtFunction,
Expand Down Expand Up @@ -99,6 +116,10 @@ struct EqFunction {
// Used for primitive inputs.
template <typename TInput>
void call(bool& out, const TInput& lhs, const TInput& rhs) {
if constexpr (std::is_floating_point_v<TInput>) {
out = util::floating_point::NaNAwareEquals<TInput>{}(lhs, rhs);
return;
}
out = (lhs == rhs);
}

Expand Down Expand Up @@ -138,6 +159,10 @@ struct NeqFunction {
// Used for primitive inputs.
template <typename TInput>
void call(bool& out, const TInput& lhs, const TInput& rhs) {
if constexpr (std::is_floating_point_v<TInput>) {
out = !util::floating_point::NaNAwareEquals<TInput>{}(lhs, rhs);
return;
}
out = (lhs != rhs);
}

Expand Down Expand Up @@ -172,6 +197,12 @@ struct BetweenFunction {
template <typename T>
FOLLY_ALWAYS_INLINE void
call(bool& result, const T& value, const T& low, const T& high) {
if constexpr (std::is_floating_point_v<T>) {
result =
util::floating_point::NaNAwareGreaterThanEqual<T>{}(value, low) &&
util::floating_point::NaNAwareLessThanEqual<T>{}(value, high);
return;
}
result = value >= low && value <= high;
}
};
Expand Down
120 changes: 97 additions & 23 deletions velox/functions/prestosql/tests/ComparisonsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <string>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/Comparisons.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox;
Expand All @@ -34,6 +36,16 @@ class ComparisonsTest : public functions::test::FunctionBaseTest {
auto actual = evaluate(exprStr, makeRowVector(input));
test::assertEqualVectors(expectedResult, actual);
}

void registerSimpleComparisonFunctions() {
using namespace facebook::velox::functions;
registerBinaryScalar<EqFunction, bool>({"simple_eq"});
registerBinaryScalar<NeqFunction, bool>({"simple_neq"});
registerBinaryScalar<LtFunction, bool>({"simple_lt"});
registerBinaryScalar<LteFunction, bool>({"simple_lte"});
registerBinaryScalar<GtFunction, bool>({"simple_gt"});
registerBinaryScalar<GteFunction, bool>({"simple_gte"});
}
};

TEST_F(ComparisonsTest, between) {
Expand Down Expand Up @@ -641,6 +653,91 @@ TEST_F(ComparisonsTest, overflowTest) {
}
}

TEST_F(ComparisonsTest, nanComparison) {
registerSimpleComparisonFunctions();
static const auto kNaN = std::numeric_limits<double>::quiet_NaN();
static const auto kSNaN = std::numeric_limits<double>::signaling_NaN();
static const auto kInf = std::numeric_limits<double>::infinity();

auto testNaN =
[&](std::string prefix, RowVectorPtr rowVector, bool primitiveInput) {
auto eval = [&](const std::string& expr,
const std::string& lhs,
const std::string& rhs) {
return evaluate<SimpleVector<bool>>(
fmt::format("{}({}, {})", expr, lhs, rhs), rowVector);
};

auto allFalse = makeFlatVector<bool>({false, false});
auto allTrue = makeFlatVector<bool>({true, true});

// NaN compared with NaN (multiple binary representations)
test::assertEqualVectors(eval(prefix + "eq", "c0", "c1"), allTrue);
test::assertEqualVectors(eval(prefix + "neq", "c0", "c1"), allFalse);
if (primitiveInput) {
test::assertEqualVectors(eval(prefix + "lt", "c0", "c1"), allFalse);
test::assertEqualVectors(eval(prefix + "gt", "c0", "c1"), allFalse);
test::assertEqualVectors(eval(prefix + "lte", "c0", "c1"), allTrue);
test::assertEqualVectors(eval(prefix + "gte", "c0", "c1"), allTrue);
// NaN between Infinity and NaN
test::assertEqualVectors(
evaluate<SimpleVector<bool>>("c0 BETWEEN c2 and c1", rowVector),
allTrue);
// NaN distinct from NaN
test::assertEqualVectors(
evaluate<SimpleVector<bool>>("c0 IS DISTINCT FROM c1", rowVector),
allFalse);
}

// NaN compared with Inf
test::assertEqualVectors(eval(prefix + "eq", "c0", "c2"), allFalse);
test::assertEqualVectors(eval(prefix + "neq", "c0", "c2"), allTrue);
if (primitiveInput) {
test::assertEqualVectors(eval(prefix + "lt", "c0", "c2"), allFalse);
test::assertEqualVectors(eval(prefix + "gt", "c0", "c2"), allTrue);
test::assertEqualVectors(eval(prefix + "lte", "c0", "c2"), allFalse);
test::assertEqualVectors(eval(prefix + "gte", "c0", "c2"), allTrue);
// NaN between 0 and Infinity
test::assertEqualVectors(
evaluate<SimpleVector<bool>>(
"c0 BETWEEN cast(0 as double) and c2", rowVector),
allFalse);
// NaN distinct from Infinity
test::assertEqualVectors(
evaluate<SimpleVector<bool>>("c0 IS DISTINCT FROM c2", rowVector),
allTrue);
}
};

// Primitive type input
auto input = makeRowVector(
{makeFlatVector<double>({kNaN, kSNaN}),
makeFlatVector<double>({kNaN, kNaN}),
makeFlatVector<double>({kInf, kInf})});
// Test the Vector function ComparisonSimdFunction.
testNaN("", input, true);
// Test the Simple functions.
testNaN("simple_", input, true);

// Complex type input
input = makeRowVector({
makeRowVector({
makeFlatVector<double>({kNaN, kSNaN}),
makeFlatVector<int32_t>({1, 1}),
}),
makeRowVector({
makeFlatVector<double>({kNaN, kNaN}),
makeFlatVector<int32_t>({1, 1}),
}),
makeRowVector({
makeFlatVector<double>({kInf, kInf}),
makeFlatVector<int32_t>({1, 1}),
}),
});
// Note: Complex comparison functions are only registered as simple functions.
testNaN("", input, false);
}

namespace {
template <typename Tp, typename Op, const char* fnName>
struct ComparisonTypeOp {
Expand Down Expand Up @@ -766,29 +863,6 @@ class SimdComparisonsTest : public functions::test::FunctionBaseTest {
rhsVector.begin(), rhsVector.end(), std::numeric_limits<T>::min());

testVectorComparison(lhsVector, rhsVector);

// Add tests against Nan and other edge cases.
if constexpr (std::is_floating_point_v<T>) {
lhsVector = std::vector<T>(47);
rhsVector = std::vector<T>(47);

std::fill(
lhsVector.begin(),
lhsVector.end(),
std::numeric_limits<T>::signaling_NaN());
std::fill(
rhsVector.begin(),
rhsVector.end(),
std::numeric_limits<T>::signaling_NaN());
testVectorComparison(lhsVector, rhsVector);

std::fill(
lhsVector.begin(),
lhsVector.end(),
std::numeric_limits<T>::signaling_NaN());
std::fill(rhsVector.begin(), rhsVector.end(), 1);
testVectorComparison(lhsVector, rhsVector);
}
}

void testDictionary() {
Expand Down
Loading

0 comments on commit 9ab6c70

Please sign in to comment.