Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

variant: Implement the three-way comparison for monostate and variant. #3

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions include/rust_enum/variant/comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include <rust_enum/variant/detail/tagged_reference.hpp>
#include <rust_enum/variant/detail/tagged_visit.hpp>

#ifdef USE_CXX20
#include <compare>
#endif

namespace rust {
namespace detail {
template <class Op, class Variant>
Expand All @@ -32,7 +36,8 @@ struct variant_comparison_impl {

template <class... Tys>
CONSTEXPR17 auto operator==(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable<std::equal_to<>, const Tys&, const Tys&>...>::value
std::conjunction<
std::is_nothrow_invocable_r<bool, std::equal_to<>, const Tys&, const Tys&>...>::value
) -> bool {
return lhs.index() == rhs.index()
&& (lhs.valueless_by_exception()
Expand All @@ -45,7 +50,7 @@ CONSTEXPR17 auto operator==(const variant<Tys...>& lhs, const variant<Tys...>& r
template <class... Tys>
CONSTEXPR17 auto operator!=(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<
std::is_nothrow_invocable<std::not_equal_to<>, const Tys&, const Tys&>...>::value
std::is_nothrow_invocable_r<bool, std::not_equal_to<>, const Tys&, const Tys&>...>::value
) -> bool {
return lhs.index() != rhs.index()
|| (!lhs.valueless_by_exception()
Expand All @@ -57,7 +62,8 @@ CONSTEXPR17 auto operator!=(const variant<Tys...>& lhs, const variant<Tys...>& r

template <class... Tys>
CONSTEXPR17 auto operator<(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable<std::less<>, const Tys&, const Tys&>...>::value
std::conjunction<
std::is_nothrow_invocable_r<bool, std::less<>, const Tys&, const Tys&>...>::value
) -> bool {
if (rhs.valueless_by_exception())
return false;
Expand All @@ -73,7 +79,8 @@ CONSTEXPR17 auto operator<(const variant<Tys...>& lhs, const variant<Tys...>& rh

template <class... Tys>
CONSTEXPR17 auto operator>(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable<std::greater<>, const Tys&, const Tys&>...>::value
std::conjunction<
std::is_nothrow_invocable_r<bool, std::greater<>, const Tys&, const Tys&>...>::value
) -> bool {
if (lhs.valueless_by_exception())
return false;
Expand All @@ -90,7 +97,7 @@ CONSTEXPR17 auto operator>(const variant<Tys...>& lhs, const variant<Tys...>& rh
template <class... Tys>
CONSTEXPR17 auto operator<=(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<
std::is_nothrow_invocable<std::less_equal<>, const Tys&, const Tys&>...>::value
std::is_nothrow_invocable_r<bool, std::less_equal<>, const Tys&, const Tys&>...>::value
) -> bool {
if (lhs.valueless_by_exception())
return true;
Expand All @@ -107,7 +114,7 @@ CONSTEXPR17 auto operator<=(const variant<Tys...>& lhs, const variant<Tys...>& r
template <class... Tys>
CONSTEXPR17 auto operator>=(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<
std::is_nothrow_invocable<std::greater_equal<>, const Tys&, const Tys&>...>::value
std::is_nothrow_invocable_r<bool, std::greater_equal<>, const Tys&, const Tys&>...>::value
) -> bool {
if (rhs.valueless_by_exception())
return true;
Expand All @@ -120,6 +127,52 @@ CONSTEXPR17 auto operator>=(const variant<Tys...>& lhs, const variant<Tys...>& r
rhs.storage()
));
}

#ifdef USE_CXX20
namespace detail {
template <class Variant, class Ret>
struct variant_three_way_comparison_impl {
const Variant& lhs;

template <std::size_t Idx, class AltStorageRef>
CONSTEXPR20 auto operator()(tagged_reference<Idx, AltStorageRef> rhs) const -> Ret {
if constexpr (std::is_same<AltStorageRef, valueless_tag>::value) {
unreachable();
} else {
return std::compare_three_way {}(
get_variant_tagged_content<Idx>(lhs.storage()).forward_content().forward_value(),
rhs.forward_content().forward_value()
);
}
}
};
} // namespace detail

template <class... Tys>
requires(std::three_way_comparable<Tys> && ...)
CONSTEXPR20 auto operator<=>(const variant<Tys...>& lhs, const variant<Tys...>& rhs) noexcept(
std::conjunction<std::is_nothrow_invocable_r<
std::common_comparison_category_t<std::compare_three_way_result_t<Tys>...>,
std::compare_three_way,
const Tys&,
const Tys&>...>::value
) -> std::common_comparison_category_t<std::compare_three_way_result_t<Tys>...> {
if (lhs.valueless_by_exception() && rhs.valueless_by_exception())
return std::strong_ordering::equal;
if (lhs.valueless_by_exception())
return std::strong_ordering::less;
if (rhs.valueless_by_exception())
return std::strong_ordering::greater;
if (auto const c = lhs.index() <=> rhs.index(); c != 0)
return c;

using ret_type = std::common_comparison_category_t<std::compare_three_way_result_t<Tys>...>;
return detail::tagged_visit(
detail::variant_three_way_comparison_impl<variant<Tys...>, ret_type> { lhs },
rhs.storage()
);
}
#endif // USE_CXX20
} // namespace rust

#endif // RUST_ENUM_INCLUDE_RUST_ENUM_VARIANT_COMPARISON_HPP
10 changes: 10 additions & 0 deletions include/rust_enum/variant/monostate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@

#include <rust_enum/variant/detail/macro.hpp>

#ifdef USE_CXX20
#include <compare>
#endif

namespace rust {
struct monostate { };

CONSTEXPR17 auto operator==(monostate, monostate) noexcept -> bool {
return true;
}

#ifdef USE_CXX20
CONSTEXPR20 auto operator<=>(monostate, monostate) noexcept -> std::strong_ordering {
return std::strong_ordering::equal;
}
#else
CONSTEXPR17 auto operator!=(monostate, monostate) noexcept -> bool {
return false;
}
Expand All @@ -29,6 +38,7 @@ CONSTEXPR17 auto operator<=(monostate, monostate) noexcept -> bool {
CONSTEXPR17 auto operator>=(monostate, monostate) noexcept -> bool {
return true;
}
#endif // USE_CXX20
} // namespace rust

#endif // RUST_ENUM_INCLUDE_RUST_ENUM_VARIANT_MONOSTATE_HPP
1 change: 1 addition & 0 deletions unittest/variant/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ rust_enum_add_unittest(variant_test
get_test.cpp
holds_alternative_test.cpp
comparison_test.cpp
comparison_three_way_test.cpp
assignment_test/converting_assignment_test.cpp
assignment_test/copy_assignment_test.cpp
assignment_test/move_assignment_test.cpp
Expand Down
7 changes: 7 additions & 0 deletions unittest/variant/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ inline auto operator>=(const valueless_t&, const valueless_t&) -> bool {
return false;
}

#ifdef USE_CXX20
inline auto operator<=>(const valueless_t&, const valueless_t&) -> std::weak_ordering {
ADD_FAILURE();
return std::weak_ordering::equivalent;
}
#endif

template <class Variant>
void make_valueless(Variant& v) {
Variant valueless(std::in_place_type<valueless_t>);
Expand Down
177 changes: 177 additions & 0 deletions unittest/variant/comparison_three_way_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#include "common.hpp"

#ifdef USE_CXX20
#include <compare>

namespace rust {
namespace {
template <class T1, class T2>
constexpr auto test_comparison_complete( //
const T1& lhs,
const T2& rhs,
std::partial_ordering order
) -> bool {
#define CHECK_BODY(op) \
if ((lhs op rhs) != (order op 0)) \
return false; \
if ((rhs op lhs) != (0 op order)) \
return false;

CHECK_BODY(==)
CHECK_BODY(!=)
CHECK_BODY(<)
CHECK_BODY(>)
CHECK_BODY(<=)
CHECK_BODY(>=)

return true;
#undef CHECK_BODY
}

template <class T1, class T2, class Order>
constexpr auto test_order(const T1& lhs, const T2& rhs, Order order) -> bool {
if ((lhs <=> rhs) != order)
return false;
return test_comparison_complete(lhs, rhs, order);
}

TEST(VariantTestThreeWayComparison, Valueless) {
{
using v = variant<int, valueless_t>;
v x1, x2;
make_valueless(x2);
EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::greater));
}
{
using v = variant<int, valueless_t>;
v x1, x2;
make_valueless(x1);
EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::less));
}
{
using v = variant<int, valueless_t>;
v x1, x2;
make_valueless(x1);
make_valueless(x2);
EXPECT_TRUE(test_order(x1, x2, std::weak_ordering::equivalent));
}
}

template <class T1, class T2, class Order>
void test_with_types() {
using v = variant<T1, T2>;

static_assert(std::is_same<decltype(std::declval<T1>() <=> std::declval<T2>()), Order>::value);
static_assert(std::is_same<decltype(std::declval<T2>() <=> std::declval<T1>()), Order>::value);
static_assert(std::is_same<decltype(std::declval<T1>() == std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() == std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() != std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() != std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() < std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() < std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() <= std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() <= std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() > std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() > std::declval<T1>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T1>() >= std::declval<T2>()), bool>::value);
static_assert(std::is_same<decltype(std::declval<T2>() >= std::declval<T1>()), bool>::value);

{
constexpr v x1(std::in_place_index<0>, T1 { 1 });
constexpr v x2(std::in_place_index<0>, T1 { 1 });
static_assert(test_order(x1, x2, Order::equivalent));
}
{
constexpr v x1(std::in_place_index<0>, T1 { 0 });
constexpr v x2(std::in_place_index<0>, T1 { 1 });
static_assert(test_order(x1, x2, Order::less));
}
{
constexpr v x1(std::in_place_index<0>, T1 { 1 });
constexpr v x2(std::in_place_index<0>, T1 { 0 });
static_assert(test_order(x1, x2, Order::greater));
}
{
constexpr v x1(std::in_place_index<0>, T1 { 1 });
constexpr v x2(std::in_place_index<1>, T1 { 1 });
static_assert(test_order(x1, x2, Order::less));
}
{
constexpr v x1(std::in_place_index<1>, T1 { 1 });
constexpr v x2(std::in_place_index<0>, T1 { 1 });
static_assert(test_order(x1, x2, Order::greater));
}
}

TEST(VariantTestThreeWayComparison, Basic) {
test_with_types<int, double, std::partial_ordering>();
test_with_types<int, long, std::strong_ordering>();

{
using v = variant<int, double>;
constexpr double nan = std::numeric_limits<double>::quiet_NaN();
{
constexpr v x1(std::in_place_type<int>, 1);
constexpr v x2(std::in_place_type<double>, nan);
EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::less));
}
{
constexpr v x1(std::in_place_type<double>, nan);
constexpr v x2(std::in_place_type<int>, 2);
EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::greater));
}
{
constexpr v x1(std::in_place_type<double>, nan);
constexpr v x2(std::in_place_type<double>, nan);
EXPECT_TRUE(test_order(x1, x2, std::partial_ordering::unordered));
}
}
}

template <class T, class U = T>
concept has_three_way_op = requires(T& t, U& u) { t <=> u; };

TEST(VariantTestThreeWayComparison, Deleted) {
// std::three_way_comparable is a more stringent requirement that demands operator== and a few
// other things.
using std::three_way_comparable;

struct has_simple_ordering {
constexpr auto operator==(const has_simple_ordering&) const -> bool;
constexpr auto operator<(const has_simple_ordering&) const -> bool;
};

struct has_only_spaceship {
constexpr auto operator==(const has_only_spaceship&) const -> bool = delete;
constexpr auto operator<=>(const has_only_spaceship&) const -> std::weak_ordering;
};

struct has_full_ordering {
constexpr auto operator==(const has_full_ordering&) const -> bool;
constexpr auto operator<=>(const has_full_ordering&) const -> std::weak_ordering;
};

// operator<=> must resolve the return types of all its union types' operator<=>s to determine
// its own return type, so it is detectable by SFINAE
static_assert(!has_three_way_op<has_simple_ordering>);
static_assert(!has_three_way_op<variant<int, has_simple_ordering>>);

static_assert(!three_way_comparable<has_simple_ordering>);
static_assert(!three_way_comparable<variant<int, has_simple_ordering>>);

static_assert(has_three_way_op<has_only_spaceship>);
static_assert(!has_three_way_op<variant<int, has_only_spaceship>>);

static_assert(!three_way_comparable<has_only_spaceship>);
static_assert(!three_way_comparable<variant<int, has_only_spaceship>>);

static_assert(has_three_way_op<has_full_ordering>);
static_assert(has_three_way_op<variant<int, has_full_ordering>>);

static_assert(three_way_comparable<has_full_ordering>);
static_assert(three_way_comparable<variant<int, has_full_ordering>>);
}
} // namespace
} // namespace rust

#endif
10 changes: 10 additions & 0 deletions unittest/variant/monostate_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ TEST(VariantTestMonostate, Compare) {
static_assert(noexcept(m1 > m2));
static_assert(noexcept(m1 <= m2));
static_assert(noexcept(m1 >= m2));

#ifdef USE_CXX20
static_assert(std::is_same<decltype(m1 <=> m2), std::strong_ordering>::value);
static_assert(noexcept(m1 <=> m2));
#endif
}

// non-constexpr
Expand All @@ -51,6 +56,11 @@ TEST(VariantTestMonostate, Compare) {
static_assert(noexcept(m1 > m2));
static_assert(noexcept(m1 <= m2));
static_assert(noexcept(m1 >= m2));

#ifdef USE_CXX20
static_assert(std::is_same<decltype(m1 <=> m2), std::strong_ordering>::value);
static_assert(noexcept(m1 <=> m2));
#endif
}
}
} // namespace
Expand Down
Loading