From 12ece8ab5e5ef176b4d79cb6b72797121740ddeb Mon Sep 17 00:00:00 2001 From: ZPlutoY <38772646+Pluto-Zy@users.noreply.github.com> Date: Mon, 7 Aug 2023 19:56:48 +0800 Subject: [PATCH] variant: Implement the three-way comparison for monostate and variant. --- include/rust_enum/variant/comparison.hpp | 65 ++++++- include/rust_enum/variant/monostate.hpp | 10 + unittest/variant/CMakeLists.txt | 1 + unittest/variant/common.hpp | 7 + .../variant/comparison_three_way_test.cpp | 177 ++++++++++++++++++ unittest/variant/monostate_test.cpp | 10 + 6 files changed, 264 insertions(+), 6 deletions(-) create mode 100644 unittest/variant/comparison_three_way_test.cpp diff --git a/include/rust_enum/variant/comparison.hpp b/include/rust_enum/variant/comparison.hpp index 8644ebd..ccd88ac 100644 --- a/include/rust_enum/variant/comparison.hpp +++ b/include/rust_enum/variant/comparison.hpp @@ -10,6 +10,10 @@ #include #include +#ifdef USE_CXX20 + #include +#endif + namespace rust { namespace detail { template @@ -32,7 +36,8 @@ struct variant_comparison_impl { template CONSTEXPR17 auto operator==(const variant& lhs, const variant& rhs) noexcept( - std::conjunction, const Tys&, const Tys&>...>::value + std::conjunction< + std::is_nothrow_invocable_r, const Tys&, const Tys&>...>::value ) -> bool { return lhs.index() == rhs.index() && (lhs.valueless_by_exception() @@ -45,7 +50,7 @@ CONSTEXPR17 auto operator==(const variant& lhs, const variant& r template CONSTEXPR17 auto operator!=(const variant& lhs, const variant& rhs) noexcept( std::conjunction< - std::is_nothrow_invocable, const Tys&, const Tys&>...>::value + std::is_nothrow_invocable_r, const Tys&, const Tys&>...>::value ) -> bool { return lhs.index() != rhs.index() || (!lhs.valueless_by_exception() @@ -57,7 +62,8 @@ CONSTEXPR17 auto operator!=(const variant& lhs, const variant& r template CONSTEXPR17 auto operator<(const variant& lhs, const variant& rhs) noexcept( - std::conjunction, const Tys&, const Tys&>...>::value + std::conjunction< + std::is_nothrow_invocable_r, const Tys&, const Tys&>...>::value ) -> bool { if (rhs.valueless_by_exception()) return false; @@ -73,7 +79,8 @@ CONSTEXPR17 auto operator<(const variant& lhs, const variant& rh template CONSTEXPR17 auto operator>(const variant& lhs, const variant& rhs) noexcept( - std::conjunction, const Tys&, const Tys&>...>::value + std::conjunction< + std::is_nothrow_invocable_r, const Tys&, const Tys&>...>::value ) -> bool { if (lhs.valueless_by_exception()) return false; @@ -90,7 +97,7 @@ CONSTEXPR17 auto operator>(const variant& lhs, const variant& rh template CONSTEXPR17 auto operator<=(const variant& lhs, const variant& rhs) noexcept( std::conjunction< - std::is_nothrow_invocable, const Tys&, const Tys&>...>::value + std::is_nothrow_invocable_r, const Tys&, const Tys&>...>::value ) -> bool { if (lhs.valueless_by_exception()) return true; @@ -107,7 +114,7 @@ CONSTEXPR17 auto operator<=(const variant& lhs, const variant& r template CONSTEXPR17 auto operator>=(const variant& lhs, const variant& rhs) noexcept( std::conjunction< - std::is_nothrow_invocable, const Tys&, const Tys&>...>::value + std::is_nothrow_invocable_r, const Tys&, const Tys&>...>::value ) -> bool { if (rhs.valueless_by_exception()) return true; @@ -120,6 +127,52 @@ CONSTEXPR17 auto operator>=(const variant& lhs, const variant& r rhs.storage() )); } + +#ifdef USE_CXX20 +namespace detail { +template +struct variant_three_way_comparison_impl { + const Variant& lhs; + + template + CONSTEXPR20 auto operator()(tagged_reference rhs) const -> Ret { + if constexpr (std::is_same::value) { + unreachable(); + } else { + return std::compare_three_way {}( + get_variant_tagged_content(lhs.storage()).forward_content().forward_value(), + rhs.forward_content().forward_value() + ); + } + } +}; +} // namespace detail + +template + requires(std::three_way_comparable && ...) +CONSTEXPR20 auto operator<=>(const variant& lhs, const variant& rhs) noexcept( + std::conjunction...>, + std::compare_three_way, + const Tys&, + const Tys&>...>::value +) -> std::common_comparison_category_t...> { + if (lhs.valueless_by_exception() && rhs.valueless_by_exception()) + return std::strong_ordering::equal; + if (lhs.valueless_by_exception()) + return std::strong_ordering::less; + if (rhs.valueless_by_exception()) + return std::strong_ordering::greater; + if (auto const c = lhs.index() <=> rhs.index(); c != 0) + return c; + + using ret_type = std::common_comparison_category_t...>; + return detail::tagged_visit( + detail::variant_three_way_comparison_impl, ret_type> { lhs }, + rhs.storage() + ); +} +#endif // USE_CXX20 } // namespace rust #endif // RUST_ENUM_INCLUDE_RUST_ENUM_VARIANT_COMPARISON_HPP diff --git a/include/rust_enum/variant/monostate.hpp b/include/rust_enum/variant/monostate.hpp index d8a1b5c..b8201f8 100644 --- a/include/rust_enum/variant/monostate.hpp +++ b/include/rust_enum/variant/monostate.hpp @@ -3,6 +3,10 @@ #include +#ifdef USE_CXX20 + #include +#endif + namespace rust { struct monostate { }; @@ -10,6 +14,11 @@ CONSTEXPR17 auto operator==(monostate, monostate) noexcept -> bool { return true; } +#ifdef USE_CXX20 +CONSTEXPR20 auto operator<=>(monostate, monostate) noexcept -> std::strong_ordering { + return std::strong_ordering::equal; +} +#else CONSTEXPR17 auto operator!=(monostate, monostate) noexcept -> bool { return false; } @@ -29,6 +38,7 @@ CONSTEXPR17 auto operator<=(monostate, monostate) noexcept -> bool { CONSTEXPR17 auto operator>=(monostate, monostate) noexcept -> bool { return true; } +#endif // USE_CXX20 } // namespace rust #endif // RUST_ENUM_INCLUDE_RUST_ENUM_VARIANT_MONOSTATE_HPP diff --git a/unittest/variant/CMakeLists.txt b/unittest/variant/CMakeLists.txt index c8bbeb0..d6d6fef 100644 --- a/unittest/variant/CMakeLists.txt +++ b/unittest/variant/CMakeLists.txt @@ -8,6 +8,7 @@ rust_enum_add_unittest(variant_test get_test.cpp holds_alternative_test.cpp comparison_test.cpp + comparison_three_way_test.cpp assignment_test/converting_assignment_test.cpp assignment_test/copy_assignment_test.cpp assignment_test/move_assignment_test.cpp diff --git a/unittest/variant/common.hpp b/unittest/variant/common.hpp index 86738c6..e909b5e 100644 --- a/unittest/variant/common.hpp +++ b/unittest/variant/common.hpp @@ -51,6 +51,13 @@ inline auto operator>=(const valueless_t&, const valueless_t&) -> bool { return false; } +#ifdef USE_CXX20 +inline auto operator<=>(const valueless_t&, const valueless_t&) -> std::weak_ordering { + ADD_FAILURE(); + return std::weak_ordering::equivalent; +} +#endif + template void make_valueless(Variant& v) { Variant valueless(std::in_place_type); diff --git a/unittest/variant/comparison_three_way_test.cpp b/unittest/variant/comparison_three_way_test.cpp new file mode 100644 index 0000000..a44f106 --- /dev/null +++ b/unittest/variant/comparison_three_way_test.cpp @@ -0,0 +1,177 @@ +#include "common.hpp" + +#ifdef USE_CXX20 + #include + +namespace rust { +namespace { +template +constexpr auto test_comparison_complete( // + const T1& lhs, + const T2& rhs, + std::partial_ordering order +) -> bool { + #define CHECK_BODY(op) \ + if ((lhs op rhs) != (order op 0)) \ + return false; \ + if ((rhs op lhs) != (0 op order)) \ + return false; + + CHECK_BODY(==) + CHECK_BODY(!=) + CHECK_BODY(<) + CHECK_BODY(>) + CHECK_BODY(<=) + CHECK_BODY(>=) + + return true; + #undef CHECK_BODY +} + +template +constexpr auto test_order(const T1& lhs, const T2& rhs, Order order) -> bool { + if ((lhs <=> rhs) != order) + return false; + return test_comparison_complete(lhs, rhs, order); +} + +TEST(VariantTestThreeWayComparison, Valueless) { + { + using v = variant; + v x1, x2; + make_valueless(x2); + EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::greater)); + } + { + using v = variant; + v x1, x2; + make_valueless(x1); + EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::less)); + } + { + using v = variant; + v x1, x2; + make_valueless(x1); + make_valueless(x2); + EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::equivalent)); + } +} + +template +void test_with_types() { + using v = variant; + + static_assert(std::is_same() <=> std::declval()), Order>::value); + static_assert(std::is_same() <=> std::declval()), Order>::value); + static_assert(std::is_same() == std::declval()), bool>::value); + static_assert(std::is_same() == std::declval()), bool>::value); + static_assert(std::is_same() != std::declval()), bool>::value); + static_assert(std::is_same() != std::declval()), bool>::value); + static_assert(std::is_same() < std::declval()), bool>::value); + static_assert(std::is_same() < std::declval()), bool>::value); + static_assert(std::is_same() <= std::declval()), bool>::value); + static_assert(std::is_same() <= std::declval()), bool>::value); + static_assert(std::is_same() > std::declval()), bool>::value); + static_assert(std::is_same() > std::declval()), bool>::value); + static_assert(std::is_same() >= std::declval()), bool>::value); + static_assert(std::is_same() >= std::declval()), bool>::value); + + { + constexpr v x1(std::in_place_index<0>, T1 { 1 }); + constexpr v x2(std::in_place_index<0>, T1 { 1 }); + static_assert(test_order(x1, x2, Order::equivalent)); + } + { + constexpr v x1(std::in_place_index<0>, T1 { 0 }); + constexpr v x2(std::in_place_index<0>, T1 { 1 }); + static_assert(test_order(x1, x2, Order::less)); + } + { + constexpr v x1(std::in_place_index<0>, T1 { 1 }); + constexpr v x2(std::in_place_index<0>, T1 { 0 }); + static_assert(test_order(x1, x2, Order::greater)); + } + { + constexpr v x1(std::in_place_index<0>, T1 { 1 }); + constexpr v x2(std::in_place_index<1>, T1 { 1 }); + static_assert(test_order(x1, x2, Order::less)); + } + { + constexpr v x1(std::in_place_index<1>, T1 { 1 }); + constexpr v x2(std::in_place_index<0>, T1 { 1 }); + static_assert(test_order(x1, x2, Order::greater)); + } +} + +TEST(VariantTestThreeWayComparison, Basic) { + test_with_types(); + test_with_types(); + + { + using v = variant; + constexpr double nan = std::numeric_limits::quiet_NaN(); + { + constexpr v x1(std::in_place_type, 1); + constexpr v x2(std::in_place_type, nan); + EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::less)); + } + { + constexpr v x1(std::in_place_type, nan); + constexpr v x2(std::in_place_type, 2); + EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::greater)); + } + { + constexpr v x1(std::in_place_type, nan); + constexpr v x2(std::in_place_type, nan); + EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::unordered)); + } + } +} + +template +concept has_three_way_op = requires(T& t, U& u) { t <=> u; }; + +TEST(VariantTestThreeWayComparison, Deleted) { + // std::three_way_comparable is a more stringent requirement that demands operator== and a few + // other things. + using std::three_way_comparable; + + struct has_simple_ordering { + constexpr auto operator==(const has_simple_ordering&) const -> bool; + constexpr auto operator<(const has_simple_ordering&) const -> bool; + }; + + struct has_only_spaceship { + constexpr auto operator==(const has_only_spaceship&) const -> bool = delete; + constexpr auto operator<=>(const has_only_spaceship&) const -> std::weak_ordering; + }; + + struct has_full_ordering { + constexpr auto operator==(const has_full_ordering&) const -> bool; + constexpr auto operator<=>(const has_full_ordering&) const -> std::weak_ordering; + }; + + // operator<=> must resolve the return types of all its union types' operator<=>s to determine + // its own return type, so it is detectable by SFINAE + static_assert(!has_three_way_op); + static_assert(!has_three_way_op>); + + static_assert(!three_way_comparable); + static_assert(!three_way_comparable>); + + static_assert(has_three_way_op); + static_assert(!has_three_way_op>); + + static_assert(!three_way_comparable); + static_assert(!three_way_comparable>); + + static_assert(has_three_way_op); + static_assert(has_three_way_op>); + + static_assert(three_way_comparable); + static_assert(three_way_comparable>); +} +} // namespace +} // namespace rust + +#endif diff --git a/unittest/variant/monostate_test.cpp b/unittest/variant/monostate_test.cpp index 795609c..60cfa7e 100644 --- a/unittest/variant/monostate_test.cpp +++ b/unittest/variant/monostate_test.cpp @@ -32,6 +32,11 @@ TEST(VariantTestMonostate, Compare) { static_assert(noexcept(m1 > m2)); static_assert(noexcept(m1 <= m2)); static_assert(noexcept(m1 >= m2)); + +#ifdef USE_CXX20 + static_assert(std::is_same m2), std::strong_ordering>::value); + static_assert(noexcept(m1 <=> m2)); +#endif } // non-constexpr @@ -51,6 +56,11 @@ TEST(VariantTestMonostate, Compare) { static_assert(noexcept(m1 > m2)); static_assert(noexcept(m1 <= m2)); static_assert(noexcept(m1 >= m2)); + +#ifdef USE_CXX20 + static_assert(std::is_same m2), std::strong_ordering>::value); + static_assert(noexcept(m1 <=> m2)); +#endif } } } // namespace