diff --git a/velox/expression/ComplexViewTypes.h b/velox/expression/ComplexViewTypes.h index 413fb2157b48b..398d91d0ee995 100644 --- a/velox/expression/ComplexViewTypes.h +++ b/velox/expression/ComplexViewTypes.h @@ -31,6 +31,59 @@ namespace facebook::velox::exec { +/// Base class for views that need comparison. Default comparison is forbidden +/// and this class requires specialization. For now, defaulting the +/// == flag to be kNullAsValue and other comparison flag to be +/// kNullAsIndeterminate. Can adjust in the future to be more configurable. +template +struct BaseView { + std::optional compare( + const T& /*other*/, + const CompareFlags /*flags*/) const { + VELOX_UNSUPPORTED("Must provide specialization"); + } + + bool operator==(const T& other) const { + static constexpr auto kEqualValueAtFlags = + CompareFlags::equality(CompareFlags::NullHandlingMode::kNullAsValue); + return this->compareOrThrow(other, kEqualValueAtFlags) == 0; + } + + bool operator<(const T& other) const { + return this->compareOrThrow(other) < 0; + } + + bool operator<=(const T& other) const { + return this->compareOrThrow(other) <= 0; + } + + bool operator>(const T& other) const { + return this->compareOrThrow(other) > 0; + } + + bool operator>=(const T& other) const { + return this->compareOrThrow(other) >= 0; + } + + bool operator!=(const T& other) const { + return this->compareOrThrow(other) != 0; + } + + private: + int64_t compareOrThrow( + const T& other, + CompareFlags flags = CompareFlags{ + .nullHandlingMode = + CompareFlags::NullHandlingMode::kNullAsIndeterminate}) const { + auto result = static_cast(this)->compare(other, flags); + // Will throw if it encounters null elements before result is determined. + VELOX_DCHECK( + result.has_value(), + "Compare should have thrown when null is encountered in child."); + return result.value(); + } +}; + template struct VectorReader; @@ -927,7 +980,7 @@ class DynamicRowView { }; template -class RowView { +class RowView : public BaseView> { using reader_t = std::tuple>...>; using types = std::tuple; @@ -967,11 +1020,40 @@ class RowView { return result; } + std::optional compare(const RowView& other, const CompareFlags flags) + const { + return compareImpl(other, flags); + } + private: void initialize() { initializeImpl(std::index_sequence_for()); } + template + std::optional compareImpl( + const RowView& other, + const CompareFlags flags) const { + if constexpr (Is < sizeof...(T)) { + auto result = std::get(*childReaders_) + ->baseVector() + ->compare( + std::get(*other.childReaders_)->baseVector(), + offset_, + other.offset_, + flags); + if (!result.has_value()) { + return std::nullopt; + } + if (result.value() != 0) { + return result.value(); + } + + return compareImpl(other, flags); + } + return 0; + } + using children_types = std::tuple; template void materializeImpl(materialize_t& result, std::index_sequence) @@ -1068,7 +1150,7 @@ struct AllGenericExceptTop> { } }; -class GenericView { +class GenericView : public BaseView { public: GenericView( const DecodedVector& decoded, @@ -1092,39 +1174,6 @@ class GenericView { return decoded_.base(); } - bool operator==(const GenericView& other) const { - return decoded_.base()->equalValueAt( - other.decoded_.base(), decodedIndex(), other.decodedIndex()); - } - - int64_t compareOrThrow(const GenericView& other) const { - static constexpr CompareFlags kFlags = { - .nullHandlingMode = - CompareFlags::NullHandlingMode::kNullAsIndeterminate}; - std::optional result = this->compare(other, kFlags); - // Will throw if it encounters null elements before result is determined. - VELOX_DCHECK( - result.has_value(), - "Compare should have thrown when null is encountered in child."); - return result.value(); - } - - bool operator<(const GenericView& other) const { - return compareOrThrow(other) < 0; - } - - bool operator<=(const GenericView& other) const { - return compareOrThrow(other) <= 0; - } - - bool operator>(const GenericView& other) const { - return compareOrThrow(other) > 0; - } - - bool operator>=(const GenericView& other) const { - return compareOrThrow(other) >= 0; - } - vector_size_t decodedIndex() const { return decoded_.index(index_); } diff --git a/velox/expression/tests/RowViewTest.cpp b/velox/expression/tests/RowViewTest.cpp index edc5ed0720fed..b5310e6a76125 100644 --- a/velox/expression/tests/RowViewTest.cpp +++ b/velox/expression/tests/RowViewTest.cpp @@ -17,13 +17,17 @@ #include #include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/expression/VectorReaders.h" #include "velox/functions/Udf.h" +#include "velox/functions/prestosql/Comparisons.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" namespace { using namespace facebook::velox; +using namespace facebook::velox::functions; +using namespace facebook::velox::test; DecodedVector* decode(DecodedVector& decoder, const BaseVector& vector) { SelectivityVector rows(vector.size()); @@ -145,6 +149,133 @@ class RowViewTest : public functions::test::FunctionBaseTest { } } } + + void compareTest() { + auto rowVector1 = makeRowVector( + {makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({1.0})}); + auto rowVector2 = makeRowVector( + {makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({2.0})}); + { + DecodedVector decoded1; + DecodedVector decoded2; + + exec::VectorReader> reader1( + decode(decoded1, *rowVector1)); + exec::VectorReader> reader2( + decode(decoded2, *rowVector2)); + + ASSERT_TRUE(reader1.isSet(0)); + ASSERT_TRUE(reader2.isSet(0)); + auto l = read(reader1, 0); + auto r = read(reader2, 0); + // Default flag for all operators other than `==` is kNullAsIndeterminate + VELOX_ASSERT_THROW(r < l, "Ordering nulls is not supported"); + VELOX_ASSERT_THROW(r <= l, "Ordering nulls is not supported"); + VELOX_ASSERT_THROW(r > l, "Ordering nulls is not supported"); + VELOX_ASSERT_THROW(r >= l, "Ordering nulls is not supported"); + + // Default flag for `==` is kNullAsValue + ASSERT_FALSE(r == l); + + // Test we can pass in a flag to change the behavior for compare + ASSERT_LT( + l.compare( + r, + CompareFlags::equality( + CompareFlags::NullHandlingMode::kNullAsValue)), + 0); + } + + // Test indeterminate ROW = [null, 2.0] against + // [null, 2.0] is indeterminate + { + auto rowVector = vectorMaker_.rowVector( + {BaseVector::createNullConstant( + ROW({{"a", INTEGER()}}), 1, pool_.get()), + makeNullableFlatVector({1.0})}); + + DecodedVector decoded1; + exec::VectorReader> reader1( + decode(decoded1, *rowVector1)); + ASSERT_TRUE(reader1.isSet(0)); + auto l = read(reader1, 0); + auto flags = CompareFlags::equality( + CompareFlags::NullHandlingMode::kNullAsIndeterminate); + ASSERT_EQ(l.compare(l, flags), kIndeterminate); + } + } + + void e2eComparisonTest() { + auto lhs = makeRowVector( + {makeFlatVector({1, 2, 3, 4, 5, 6}), + makeFlatVector({1.0, 2.0, 3.0, 4.0, 6.0, 0.0})}); + auto rhs = makeRowVector( + {makeNullableFlatVector({5, 4, 3, 4, 5, 6}), + makeFlatVector({2.0, 2.0, 3.0, 4.0, 6.0, 1.1})}); + + registerFunction< + EqFunction, + bool, + Row, + Row>({"row_eq"}); + auto result = + evaluate>("row_eq(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({false, false, true, true, true, false}), result); + + registerFunction< + NeqFunction, + bool, + Row, + Row>({"row_neq"}); + result = evaluate>( + "row_neq(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({true, true, false, false, false, true}), result); + + registerFunction< + LtFunction, + bool, + Row, + Row>({"row_lt"}); + result = + evaluate>("row_lt(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({true, true, false, false, false, true}), result); + + registerFunction< + GtFunction, + bool, + Row, + Row>({"row_gt"}); + result = + evaluate>("row_gt(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({false, false, false, false, false, false}), + result); + + registerFunction< + LteFunction, + bool, + Row, + Row>({"row_lte"}); + result = evaluate>( + "row_lte(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({true, true, true, true, true, true}), result); + + registerFunction< + GteFunction, + bool, + Row, + Row>({"row_gte"}); + result = evaluate>( + "row_gte(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({false, false, true, true, true, false}), result); + } }; class NullableRowViewTest : public RowViewTest {}; @@ -188,6 +319,13 @@ TEST_F(NullFreeRowViewTest, materialize) { 1, "hi", {1, 2, 3}}; ASSERT_EQ(reader.readNullFree(0).materialize(), expected); } +TEST_F(NullFreeRowViewTest, compare) { + compareTest(); +} + +TEST_F(NullFreeRowViewTest, e2eCompare) { + e2eComparisonTest(); +} TEST_F(NullableRowViewTest, materialize) { auto result = evaluate( @@ -299,16 +437,16 @@ TEST_F(DynamicRowViewTest, castToDynamicRowInFunction) { // Input is not struct. auto result = evaluate("struct_width(c0)", makeRowVector({flatVector})); - test::assertEqualVectors(makeFlatVector({0, 0}), result); + assertEqualVectors(makeFlatVector({0, 0}), result); result = evaluate( "struct_width(c0)", makeRowVector({makeRowVector({flatVector})})); - test::assertEqualVectors(makeFlatVector({1, 1}), result); + assertEqualVectors(makeFlatVector({1, 1}), result); result = evaluate( "struct_width(c0)", makeRowVector({makeRowVector({flatVector, flatVector})})); - test::assertEqualVectors(makeFlatVector({2, 2}), result); + assertEqualVectors(makeFlatVector({2, 2}), result); } } } // namespace