Skip to content

Commit

Permalink
variant: Implement the three-way comparison for monostate and variant.
Browse files Browse the repository at this point in the history
  • Loading branch information
Pluto-Zy committed Aug 7, 2023
1 parent 0b44265 commit 12ece8a
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 6 deletions.
65 changes: 59 additions & 6 deletions include/rust_enum/variant/comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include <rust_enum/variant/detail/tagged_reference.hpp>
#include <rust_enum/variant/detail/tagged_visit.hpp>

#ifdef USE_CXX20
#include <compare>
#endif

namespace rust {
namespace detail {
template <class Op, class Variant>
Expand All @@ -32,7 +36,8 @@ struct variant_comparison_impl {

template <class... Tys>
CONSTEXPR17 auto operator==(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable<std::equal_to<>, const Tys&, const Tys&>...>::value
std::conjunction<
std::is_nothrow_invocable_r<bool, std::equal_to<>, const Tys&, const Tys&>...>::value
) -> bool {
return lhs.index() == rhs.index()
&& (lhs.valueless_by_exception()
Expand All @@ -45,7 +50,7 @@ CONSTEXPR17 auto operator==(const variant<Tys...>& lhs, const variant<Tys...>& r
template <class... Tys>
CONSTEXPR17 auto operator!=(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<
std::is_nothrow_invocable<std::not_equal_to<>, const Tys&, const Tys&>...>::value
std::is_nothrow_invocable_r<bool, std::not_equal_to<>, const Tys&, const Tys&>...>::value
) -> bool {
return lhs.index() != rhs.index()
|| (!lhs.valueless_by_exception()
Expand All @@ -57,7 +62,8 @@ CONSTEXPR17 auto operator!=(const variant<Tys...>& lhs, const variant<Tys...>& r

template <class... Tys>
CONSTEXPR17 auto operator<(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable<std::less<>, const Tys&, const Tys&>...>::value
std::conjunction<
std::is_nothrow_invocable_r<bool, std::less<>, const Tys&, const Tys&>...>::value
) -> bool {
if (rhs.valueless_by_exception())
return false;
Expand All @@ -73,7 +79,8 @@ CONSTEXPR17 auto operator<(const variant<Tys...>& lhs, const variant<Tys...>& rh

template <class... Tys>
CONSTEXPR17 auto operator>(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable<std::greater<>, const Tys&, const Tys&>...>::value
std::conjunction<
std::is_nothrow_invocable_r<bool, std::greater<>, const Tys&, const Tys&>...>::value
) -> bool {
if (lhs.valueless_by_exception())
return false;
Expand All @@ -90,7 +97,7 @@ CONSTEXPR17 auto operator>(const variant<Tys...>& lhs, const variant<Tys...>& rh
template <class... Tys>
CONSTEXPR17 auto operator<=(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<
std::is_nothrow_invocable<std::less_equal<>, const Tys&, const Tys&>...>::value
std::is_nothrow_invocable_r<bool, std::less_equal<>, const Tys&, const Tys&>...>::value
) -> bool {
if (lhs.valueless_by_exception())
return true;
Expand All @@ -107,7 +114,7 @@ CONSTEXPR17 auto operator<=(const variant<Tys...>& lhs, const variant<Tys...>& r
template <class... Tys>
CONSTEXPR17 auto operator>=(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<
std::is_nothrow_invocable<std::greater_equal<>, const Tys&, const Tys&>...>::value
std::is_nothrow_invocable_r<bool, std::greater_equal<>, const Tys&, const Tys&>...>::value
) -> bool {
if (rhs.valueless_by_exception())
return true;
Expand All @@ -120,6 +127,52 @@ CONSTEXPR17 auto operator>=(const variant<Tys...>& lhs, const variant<Tys...>& r
rhs.storage()
));
}

#ifdef USE_CXX20
namespace detail {
template <class Variant, class Ret>
struct variant_three_way_comparison_impl {
const Variant& lhs;

template <std::size_t Idx, class AltStorageRef>
CONSTEXPR20 auto operator()(tagged_reference<Idx, AltStorageRef> rhs) const -> Ret {
if constexpr (std::is_same<AltStorageRef, valueless_tag>::value) {
unreachable();
} else {
return std::compare_three_way {}(
get_variant_tagged_content<Idx>(lhs.storage()).forward_content().forward_value(),
rhs.forward_content().forward_value()
);
}
}
};
} // namespace detail

template <class... Tys>
requires(std::three_way_comparable<Tys> && ...)
CONSTEXPR20 auto operator<=>(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable_r<
std::common_comparison_category_t<std::compare_three_way_result_t<Tys>...>,
std::compare_three_way,
const Tys&,
const Tys&>...>::value
) -> std::common_comparison_category_t<std::compare_three_way_result_t<Tys>...> {
if (lhs.valueless_by_exception() && rhs.valueless_by_exception())
return std::strong_ordering::equal;
if (lhs.valueless_by_exception())
return std::strong_ordering::less;
if (rhs.valueless_by_exception())
return std::strong_ordering::greater;
if (auto const c = lhs.index() <=> rhs.index(); c != 0)
return c;

using ret_type = std::common_comparison_category_t<std::compare_three_way_result_t<Tys>...>;
return detail::tagged_visit(
detail::variant_three_way_comparison_impl<variant<Tys...>, ret_type> { lhs },
rhs.storage()
);
}
#endif // USE_CXX20
} // namespace rust

#endif // RUST_ENUM_INCLUDE_RUST_ENUM_VARIANT_COMPARISON_HPP
10 changes: 10 additions & 0 deletions include/rust_enum/variant/monostate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@

#include <rust_enum/variant/detail/macro.hpp>

#ifdef USE_CXX20
#include <compare>
#endif

namespace rust {
struct monostate { };

CONSTEXPR17 auto operator==(monostate, monostate) noexcept -> bool {
return true;
}

#ifdef USE_CXX20
CONSTEXPR20 auto operator<=>(monostate, monostate) noexcept -> std::strong_ordering {
return std::strong_ordering::equal;
}
#else
CONSTEXPR17 auto operator!=(monostate, monostate) noexcept -> bool {
return false;
}
Expand All @@ -29,6 +38,7 @@ CONSTEXPR17 auto operator<=(monostate, monostate) noexcept -> bool {
CONSTEXPR17 auto operator>=(monostate, monostate) noexcept -> bool {
return true;
}
#endif // USE_CXX20
} // namespace rust

#endif // RUST_ENUM_INCLUDE_RUST_ENUM_VARIANT_MONOSTATE_HPP
1 change: 1 addition & 0 deletions unittest/variant/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ rust_enum_add_unittest(variant_test
get_test.cpp
holds_alternative_test.cpp
comparison_test.cpp
comparison_three_way_test.cpp
assignment_test/converting_assignment_test.cpp
assignment_test/copy_assignment_test.cpp
assignment_test/move_assignment_test.cpp
Expand Down
7 changes: 7 additions & 0 deletions unittest/variant/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ inline auto operator>=(const valueless_t&, const valueless_t&) -> bool {
return false;
}

#ifdef USE_CXX20
inline auto operator<=>(const valueless_t&, const valueless_t&) -> std::weak_ordering {
ADD_FAILURE();
return std::weak_ordering::equivalent;
}
#endif

template <class Variant>
void make_valueless(Variant& v) {
Variant valueless(std::in_place_type<valueless_t>);
Expand Down
177 changes: 177 additions & 0 deletions unittest/variant/comparison_three_way_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#include "common.hpp"

#ifdef USE_CXX20
#include <compare>

namespace rust {
namespace {
template <class T1, class T2>
constexpr auto test_comparison_complete( //
const T1& lhs,
const T2& rhs,
std::partial_ordering order
) -> bool {
#define CHECK_BODY(op) \
if ((lhs op rhs) != (order op 0)) \
return false; \
if ((rhs op lhs) != (0 op order)) \
return false;

CHECK_BODY(==)
CHECK_BODY(!=)
CHECK_BODY(<)
CHECK_BODY(>)
CHECK_BODY(<=)
CHECK_BODY(>=)

return true;
#undef CHECK_BODY
}

template <class T1, class T2, class Order>
constexpr auto test_order(const T1& lhs, const T2& rhs, Order order) -> bool {
if ((lhs <=> rhs) != order)
return false;
return test_comparison_complete(lhs, rhs, order);
}

TEST(VariantTestThreeWayComparison, Valueless) {
{
using v = variant<int, valueless_t>;
v x1, x2;
make_valueless(x2);
EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::greater));
}
{
using v = variant<int, valueless_t>;
v x1, x2;
make_valueless(x1);
EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::less));
}
{
using v = variant<int, valueless_t>;
v x1, x2;
make_valueless(x1);
make_valueless(x2);
EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::equivalent));
}
}

template <class T1, class T2, class Order>
void test_with_types() {
using v = variant<T1, T2>;

static_assert(std::is_same<decltype(std::declval<T1>() <=> std::declval<T2>()), Order>::value);
static_assert(std::is_same<decltype(std::declval<T2>() <=> std::declval<T1>()), Order>::value);
static_assert(std::is_same<decltype(std::declval<T1>() == std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() == std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() != std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() != std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() < std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() < std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() <= std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() <= std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() > std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() > std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() >= std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() >= std::declval<T1>()), bool>::value);

{
constexpr v x1(std::in_place_index<0>, T1 { 1 });
constexpr v x2(std::in_place_index<0>, T1 { 1 });
static_assert(test_order(x1, x2, Order::equivalent));
}
{
constexpr v x1(std::in_place_index<0>, T1 { 0 });
constexpr v x2(std::in_place_index<0>, T1 { 1 });
static_assert(test_order(x1, x2, Order::less));
}
{
constexpr v x1(std::in_place_index<0>, T1 { 1 });
constexpr v x2(std::in_place_index<0>, T1 { 0 });
static_assert(test_order(x1, x2, Order::greater));
}
{
constexpr v x1(std::in_place_index<0>, T1 { 1 });
constexpr v x2(std::in_place_index<1>, T1 { 1 });
static_assert(test_order(x1, x2, Order::less));
}
{
constexpr v x1(std::in_place_index<1>, T1 { 1 });
constexpr v x2(std::in_place_index<0>, T1 { 1 });
static_assert(test_order(x1, x2, Order::greater));
}
}

TEST(VariantTestThreeWayComparison, Basic) {
test_with_types<int, double, std::partial_ordering>();
test_with_types<int, long, std::strong_ordering>();

{
using v = variant<int, double>;
constexpr double nan = std::numeric_limits<double>::quiet_NaN();
{
constexpr v x1(std::in_place_type<int>, 1);
constexpr v x2(std::in_place_type<double>, nan);
EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::less));
}
{
constexpr v x1(std::in_place_type<double>, nan);
constexpr v x2(std::in_place_type<int>, 2);
EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::greater));
}
{
constexpr v x1(std::in_place_type<double>, nan);
constexpr v x2(std::in_place_type<double>, nan);
EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::unordered));
}
}
}

template <class T, class U = T>
concept has_three_way_op = requires(T& t, U& u) { t <=> u; };

TEST(VariantTestThreeWayComparison, Deleted) {
// std::three_way_comparable is a more stringent requirement that demands operator== and a few
// other things.
using std::three_way_comparable;

struct has_simple_ordering {
constexpr auto operator==(const has_simple_ordering&) const -> bool;
constexpr auto operator<(const has_simple_ordering&) const -> bool;
};

struct has_only_spaceship {
constexpr auto operator==(const has_only_spaceship&) const -> bool = delete;
constexpr auto operator<=>(const has_only_spaceship&) const -> std::weak_ordering;
};

struct has_full_ordering {
constexpr auto operator==(const has_full_ordering&) const -> bool;
constexpr auto operator<=>(const has_full_ordering&) const -> std::weak_ordering;
};

// operator<=> must resolve the return types of all its union types' operator<=>s to determine
// its own return type, so it is detectable by SFINAE
static_assert(!has_three_way_op<has_simple_ordering>);
static_assert(!has_three_way_op<variant<int, has_simple_ordering>>);

static_assert(!three_way_comparable<has_simple_ordering>);
static_assert(!three_way_comparable<variant<int, has_simple_ordering>>);

static_assert(has_three_way_op<has_only_spaceship>);
static_assert(!has_three_way_op<variant<int, has_only_spaceship>>);

static_assert(!three_way_comparable<has_only_spaceship>);
static_assert(!three_way_comparable<variant<int, has_only_spaceship>>);

static_assert(has_three_way_op<has_full_ordering>);
static_assert(has_three_way_op<variant<int, has_full_ordering>>);

static_assert(three_way_comparable<has_full_ordering>);
static_assert(three_way_comparable<variant<int, has_full_ordering>>);
}
} // namespace
} // namespace rust

#endif
10 changes: 10 additions & 0 deletions unittest/variant/monostate_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ TEST(VariantTestMonostate, Compare) {
static_assert(noexcept(m1 > m2));
static_assert(noexcept(m1 <= m2));
static_assert(noexcept(m1 >= m2));

#ifdef USE_CXX20
static_assert(std::is_same<decltype(m1 <=> m2), std::strong_ordering>::value);
static_assert(noexcept(m1 <=> m2));
#endif
}

// non-constexpr
Expand All @@ -51,6 +56,11 @@ TEST(VariantTestMonostate, Compare) {
static_assert(noexcept(m1 > m2));
static_assert(noexcept(m1 <= m2));
static_assert(noexcept(m1 >= m2));

#ifdef USE_CXX20
static_assert(std::is_same<decltype(m1 <=> m2), std::strong_ordering>::value);
static_assert(noexcept(m1 <=> m2));
#endif
}
}
} // namespace
Expand Down

0 comments on commit 12ece8a

Please sign in to comment.