From d7c5383ead7fe7b8c46893404487bc35a8b5d95c Mon Sep 17 00:00:00 2001 From: Yenda Li Date: Mon, 11 Nov 2024 09:42:52 -0800 Subject: [PATCH] Add comparison support for RowView (#11499) Summary: Support comparison for RowView. This will allow us to compare IPPrefix, which has an underlying type of Row. BETWEEN and other comparison operations need RowView comparisons, which are not implemented yet. We can extend RowView with comparisons similar to GenericView's. I introduce a base class that implements the comparison operators in terms of compare(). Classes that derive from the base class provide their own compare(). For RowView, I iterate through the child readers one by one and return the first non-zero (or indeterminate) result of comparing the corresponding child vectors; if every child compares equal, the result is 0. Differential Revision: D65700875 --- velox/expression/ComplexViewTypes.h | 119 ++++++++++++++------ velox/expression/tests/RowViewTest.cpp | 144 ++++++++++++++++++++++++- 2 files changed, 225 insertions(+), 38 deletions(-) diff --git a/velox/expression/ComplexViewTypes.h b/velox/expression/ComplexViewTypes.h index 413fb2157b48b..398d91d0ee995 100644 --- a/velox/expression/ComplexViewTypes.h +++ b/velox/expression/ComplexViewTypes.h @@ -31,6 +31,59 @@ namespace facebook::velox::exec { +/// Base class for views that need comparison. The default compare() throws, +/// so the derived class must provide its own compare(). For now, operator== +/// uses kNullAsValue and all other comparison operators use +/// kNullAsIndeterminate. This can be made configurable in the future. 
+template +struct BaseView { + std::optional compare( + const T& /*other*/, + const CompareFlags /*flags*/) const { + VELOX_UNSUPPORTED("Must provide specialization"); + } + + bool operator==(const T& other) const { + static constexpr auto kEqualValueAtFlags = + CompareFlags::equality(CompareFlags::NullHandlingMode::kNullAsValue); + return this->compareOrThrow(other, kEqualValueAtFlags) == 0; + } + + bool operator<(const T& other) const { + return this->compareOrThrow(other) < 0; + } + + bool operator<=(const T& other) const { + return this->compareOrThrow(other) <= 0; + } + + bool operator>(const T& other) const { + return this->compareOrThrow(other) > 0; + } + + bool operator>=(const T& other) const { + return this->compareOrThrow(other) >= 0; + } + + bool operator!=(const T& other) const { + return this->compareOrThrow(other) != 0; + } + + private: + int64_t compareOrThrow( + const T& other, + CompareFlags flags = CompareFlags{ + .nullHandlingMode = + CompareFlags::NullHandlingMode::kNullAsIndeterminate}) const { + auto result = static_cast(this)->compare(other, flags); + // Will throw if it encounters null elements before result is determined. 
+ VELOX_DCHECK( + result.has_value(), + "Compare should have thrown when null is encountered in child."); + return result.value(); + } +}; + template struct VectorReader; @@ -927,7 +980,7 @@ class DynamicRowView { }; template -class RowView { +class RowView : public BaseView> { using reader_t = std::tuple>...>; using types = std::tuple; @@ -967,11 +1020,40 @@ class RowView { return result; } + std::optional compare(const RowView& other, const CompareFlags flags) + const { + return compareImpl(other, flags); + } + private: void initialize() { initializeImpl(std::index_sequence_for()); } + template + std::optional compareImpl( + const RowView& other, + const CompareFlags flags) const { + if constexpr (Is < sizeof...(T)) { + auto result = std::get(*childReaders_) + ->baseVector() + ->compare( + std::get(*other.childReaders_)->baseVector(), + offset_, + other.offset_, + flags); + if (!result.has_value()) { + return std::nullopt; + } + if (result.value() != 0) { + return result.value(); + } + + return compareImpl(other, flags); + } + return 0; + } + using children_types = std::tuple; template void materializeImpl(materialize_t& result, std::index_sequence) @@ -1068,7 +1150,7 @@ struct AllGenericExceptTop> { } }; -class GenericView { +class GenericView : public BaseView { public: GenericView( const DecodedVector& decoded, @@ -1092,39 +1174,6 @@ class GenericView { return decoded_.base(); } - bool operator==(const GenericView& other) const { - return decoded_.base()->equalValueAt( - other.decoded_.base(), decodedIndex(), other.decodedIndex()); - } - - int64_t compareOrThrow(const GenericView& other) const { - static constexpr CompareFlags kFlags = { - .nullHandlingMode = - CompareFlags::NullHandlingMode::kNullAsIndeterminate}; - std::optional result = this->compare(other, kFlags); - // Will throw if it encounters null elements before result is determined. 
- VELOX_DCHECK( - result.has_value(), - "Compare should have thrown when null is encountered in child."); - return result.value(); - } - - bool operator<(const GenericView& other) const { - return compareOrThrow(other) < 0; - } - - bool operator<=(const GenericView& other) const { - return compareOrThrow(other) <= 0; - } - - bool operator>(const GenericView& other) const { - return compareOrThrow(other) > 0; - } - - bool operator>=(const GenericView& other) const { - return compareOrThrow(other) >= 0; - } - vector_size_t decodedIndex() const { return decoded_.index(index_); } diff --git a/velox/expression/tests/RowViewTest.cpp b/velox/expression/tests/RowViewTest.cpp index edc5ed0720fed..b5310e6a76125 100644 --- a/velox/expression/tests/RowViewTest.cpp +++ b/velox/expression/tests/RowViewTest.cpp @@ -17,13 +17,17 @@ #include #include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/expression/VectorReaders.h" #include "velox/functions/Udf.h" +#include "velox/functions/prestosql/Comparisons.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" namespace { using namespace facebook::velox; +using namespace facebook::velox::functions; +using namespace facebook::velox::test; DecodedVector* decode(DecodedVector& decoder, const BaseVector& vector) { SelectivityVector rows(vector.size()); @@ -145,6 +149,133 @@ class RowViewTest : public functions::test::FunctionBaseTest { } } } + + void compareTest() { + auto rowVector1 = makeRowVector( + {makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({1.0})}); + auto rowVector2 = makeRowVector( + {makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({2.0})}); + { + DecodedVector decoded1; + DecodedVector decoded2; + + exec::VectorReader> reader1( + decode(decoded1, *rowVector1)); + exec::VectorReader> reader2( + decode(decoded2, *rowVector2)); + + ASSERT_TRUE(reader1.isSet(0)); + ASSERT_TRUE(reader2.isSet(0)); + auto l = read(reader1, 0); + auto r = read(reader2, 0); + // 
Default flag for all operators other than `==` is kNullAsIndeterminate + VELOX_ASSERT_THROW(r < l, "Ordering nulls is not supported"); + VELOX_ASSERT_THROW(r <= l, "Ordering nulls is not supported"); + VELOX_ASSERT_THROW(r > l, "Ordering nulls is not supported"); + VELOX_ASSERT_THROW(r >= l, "Ordering nulls is not supported"); + + // Default flag for `==` is kNullAsValue + ASSERT_FALSE(r == l); + + // Test we can pass in a flag to change the behavior for compare + ASSERT_LT( + l.compare( + r, + CompareFlags::equality( + CompareFlags::NullHandlingMode::kNullAsValue)), + 0); + } + + // Test indeterminate ROW = [null, 2.0] against + // [null, 2.0] is indeterminate + { + auto rowVector = vectorMaker_.rowVector( + {BaseVector::createNullConstant( + ROW({{"a", INTEGER()}}), 1, pool_.get()), + makeNullableFlatVector({1.0})}); + + DecodedVector decoded1; + exec::VectorReader> reader1( + decode(decoded1, *rowVector1)); + ASSERT_TRUE(reader1.isSet(0)); + auto l = read(reader1, 0); + auto flags = CompareFlags::equality( + CompareFlags::NullHandlingMode::kNullAsIndeterminate); + ASSERT_EQ(l.compare(l, flags), kIndeterminate); + } + } + + void e2eComparisonTest() { + auto lhs = makeRowVector( + {makeFlatVector({1, 2, 3, 4, 5, 6}), + makeFlatVector({1.0, 2.0, 3.0, 4.0, 6.0, 0.0})}); + auto rhs = makeRowVector( + {makeNullableFlatVector({5, 4, 3, 4, 5, 6}), + makeFlatVector({2.0, 2.0, 3.0, 4.0, 6.0, 1.1})}); + + registerFunction< + EqFunction, + bool, + Row, + Row>({"row_eq"}); + auto result = + evaluate>("row_eq(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({false, false, true, true, true, false}), result); + + registerFunction< + NeqFunction, + bool, + Row, + Row>({"row_neq"}); + result = evaluate>( + "row_neq(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({true, true, false, false, false, true}), result); + + registerFunction< + LtFunction, + bool, + Row, + Row>({"row_lt"}); + result = + evaluate>("row_lt(c0, c1)", 
makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({true, true, false, false, false, true}), result); + + registerFunction< + GtFunction, + bool, + Row, + Row>({"row_gt"}); + result = + evaluate>("row_gt(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({false, false, false, false, false, false}), + result); + + registerFunction< + LteFunction, + bool, + Row, + Row>({"row_lte"}); + result = evaluate>( + "row_lte(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({true, true, true, true, true, true}), result); + + registerFunction< + GteFunction, + bool, + Row, + Row>({"row_gte"}); + result = evaluate>( + "row_gte(c0, c1)", makeRowVector({lhs, rhs})); + assertEqualVectors( + makeFlatVector({false, false, true, true, true, false}), result); + } }; class NullableRowViewTest : public RowViewTest {}; @@ -188,6 +319,13 @@ TEST_F(NullFreeRowViewTest, materialize) { 1, "hi", {1, 2, 3}}; ASSERT_EQ(reader.readNullFree(0).materialize(), expected); } +TEST_F(NullFreeRowViewTest, compare) { + compareTest(); +} + +TEST_F(NullFreeRowViewTest, e2eCompare) { + e2eComparisonTest(); +} TEST_F(NullableRowViewTest, materialize) { auto result = evaluate( @@ -299,16 +437,16 @@ TEST_F(DynamicRowViewTest, castToDynamicRowInFunction) { // Input is not struct. auto result = evaluate("struct_width(c0)", makeRowVector({flatVector})); - test::assertEqualVectors(makeFlatVector({0, 0}), result); + assertEqualVectors(makeFlatVector({0, 0}), result); result = evaluate( "struct_width(c0)", makeRowVector({makeRowVector({flatVector})})); - test::assertEqualVectors(makeFlatVector({1, 1}), result); + assertEqualVectors(makeFlatVector({1, 1}), result); result = evaluate( "struct_width(c0)", makeRowVector({makeRowVector({flatVector, flatVector})})); - test::assertEqualVectors(makeFlatVector({2, 2}), result); + assertEqualVectors(makeFlatVector({2, 2}), result); } } } // namespace