Skip to content

Commit

Permalink
Refactor helper function signatures. Removed SFINAE test.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Oct 2, 2024
1 parent d7743ee commit 61db7f8
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 95 deletions.
13 changes: 6 additions & 7 deletions src/atlas/array/ArrayViewVariant.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,20 @@ ArrayViewVariant make_device_view_variant(Array& array);
ConstArrayViewVariant make_device_view_variant(const Array& array);

/// @brief Return true if View::rank() is any of Ranks...
template <typename View, int... Ranks>
constexpr bool is_rank() {
template <int... Ranks, typename View>
constexpr bool is_rank(const View&) {
return ((std::decay_t<View>::rank() == Ranks) || ...);
}

/// @brief Return true if View::value_type is any of ValuesTypes...
template <typename View, typename... ValueTypes>
constexpr bool is_value_type() {
template <typename... ValueTypes, typename View>
constexpr bool is_value_type(const View&) {
using ValueType = typename std::decay_t<View>::value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}

/// @brief Return true if View::non_const_value_type is any of ValuesTypes...
template <typename View, typename... ValueTypes>
constexpr bool is_non_const_value_type() {
template <typename... ValueTypes, typename View>
constexpr bool is_non_const_value_type(const View&) {
using ValueType = typename std::decay_t<View>::non_const_value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}
Expand Down
130 changes: 42 additions & 88 deletions src/tests/array/test_array_view_variant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,38 @@ CASE("test variant assignment") {
auto deviceView3 = make_device_view_variant(array3);
auto deviceView4 = make_device_view_variant(arrayRef);

auto view = make_view<float, 1>(array1);

const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) {
std::visit(
[](auto&& view) {
using View = decltype(view);
EXPECT((is_rank<View, 1>()));
EXPECT((is_value_type<View, float>()));
EXPECT((is_non_const_value_type<View, float>()));
EXPECT((is_rank<1>(view)));
EXPECT((is_value_type<float>(view)));
EXPECT((is_non_const_value_type<float>(view)));
},
var1);

std::visit(
[](auto&& view) {
using View = decltype(view);
EXPECT((is_rank<View, 2>()));
EXPECT((is_value_type<View, double>()));
EXPECT((is_non_const_value_type<View, double>()));
EXPECT((is_rank<2>(view)));
EXPECT((is_value_type<double>(view)));
EXPECT((is_non_const_value_type<double>(view)));
},
var2);

std::visit(
[](auto&& view) {
using View = decltype(view);
EXPECT((is_rank<View, 3>()));
EXPECT((is_value_type<View, int>()));
EXPECT((is_non_const_value_type<View, int>()));
EXPECT((is_rank<3>(view)));
EXPECT((is_value_type<int>(view)));
EXPECT((is_non_const_value_type<int>(view)));
},
var3);

std::visit(
[](auto&& view) {
using View = decltype(view);
EXPECT((is_rank<View, 1>()));
EXPECT((is_value_type<View, const float>()));
EXPECT((is_non_const_value_type<View, float>()));
EXPECT((is_rank<1>(view)));
EXPECT((is_value_type<const float>(view)));
EXPECT((is_non_const_value_type<float>(view)));
},
var4);
};
Expand All @@ -93,84 +91,40 @@ CASE("test std::visit") {
auto array2 = ArrayT<int>(5, 2);
make_view<int, 2>(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

const auto testRank1 = [](auto&& view) {


using View = std::remove_reference_t<decltype(view)>;
using Value = typename View::value_type;

EXPECT((is_rank<View, 1>()));
EXPECT((is_value_type<View, int>()));
for (auto i = size_t{0}; i < view.size(); ++i) {
EXPECT_EQ(view(i), static_cast<Value>(i));
}
};

const auto testRank2 = [](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
using Value = typename View::value_type;

EXPECT((is_rank<View, 2>()));
EXPECT((is_value_type<View, int>()));

auto testValue = int{0};
for (auto i = idx_t{0}; i < view.shape(0); ++i) {
for (auto j = idx_t{0}; j < view.shape(1); ++j) {
EXPECT_EQ(view(i, j), static_cast<Value>(testValue++));
}
}
};

const auto var1 = make_view_variant(array1);
const auto var2 = make_view_variant(array2);

SECTION("demonstrate 'if constexpr' pattern") {
auto rank1Tested = false;
auto rank2Tested = false;

const auto visitor = [&](auto&& view) {
if constexpr (is_rank<decltype(view), 1>()) {
rank1Tested = true;
return testRank1(view);
auto rank1Tested = false;
auto rank2Tested = false;

const auto visitor = [&](auto&& view) {
if constexpr (is_rank<1>(view)) {
EXPECT((is_value_type<int>(view)));
auto testValue = int{0};
for (auto i = size_t{0}; i < view.size(); ++i) {
const auto value = view(i);
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++));
}
if constexpr (is_rank<decltype(view), 2>()) {
rank2Tested = true;
return testRank2(view);
rank1Tested = true;
} else if constexpr (is_rank<2>(view)) {
EXPECT((is_value_type<int>(view)));
auto testValue = int{0};
for (auto i = idx_t{0}; i < view.shape(0); ++i) {
for (auto j = idx_t{0}; j < view.shape(1); ++j) {
const auto value = view(i, j);
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++));
}
}
rank2Tested = true;
} else {
// Test should not reach here.
EXPECT(false);
};

std::visit(visitor, var1);
EXPECT(rank1Tested);
std::visit(visitor, var2);
EXPECT(rank2Tested);
}

SECTION("demonstrate 'overloaded' pattern") {
// Note, SFINAE can be eliminated using concepts and explicit lambda
// templates in C++20.
auto rank1Tested = false;
auto rank2Tested = false;
const auto visitor = eckit::Overloaded{
[&](auto&& view) -> std::enable_if_t<is_rank<decltype(view), 1>()> {
testRank1(view);
rank1Tested = true;
},
[&](auto&& view) -> std::enable_if_t<is_rank<decltype(view), 2>()> {
testRank2(view);
rank2Tested = true;
},
[](auto&& view) -> std::enable_if_t<!is_rank<decltype(view), 1, 2>()> {
// Test should not reach here.
EXPECT(false);
}};

std::visit(visitor, var1);
EXPECT(rank1Tested);
std::visit(visitor, var2);
EXPECT(rank2Tested);
}
}
};

std::visit(visitor, var1);
EXPECT(rank1Tested);
std::visit(visitor, var2);
EXPECT(rank2Tested);
}

} // namespace test
Expand Down

0 comments on commit 61db7f8

Please sign in to comment.