Skip to content

Commit

Permalink
Refactored introspection helpers.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Oct 2, 2024
1 parent 3590f0c commit d7743ee
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 35 deletions.
27 changes: 17 additions & 10 deletions src/atlas/array/ArrayViewVariant.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ using Variant = typename VariantHelper<Values, Ranks>::type;
} // namespace detail

/// @brief Supported ArrayView value types.
using Values = detail::Types<float, double, int, long, unsigned long>;
using ValueTypes = detail::Types<float, double, int, long, unsigned long>;

/// @brief Supported ArrayView ranks.
using Ranks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>;

/// @brief Variant containing all supported non-const ArrayView alternatives.
using ArrayViewVariant = detail::Variant<Values, Ranks>;
using ArrayViewVariant = detail::Variant<ValueTypes, Ranks>;

/// @brief Variant containing all supported const ArrayView alternatives.
using ConstArrayViewVariant = detail::Variant<Values::add_const, Ranks>;
using ConstArrayViewVariant = detail::Variant<ValueTypes::add_const, Ranks>;

/// @brief Create an ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_view_variant(Array& array);
Expand All @@ -79,17 +79,24 @@ ArrayViewVariant make_device_view_variant(Array& array);
/// @brief Create a const device ArrayView and assign to an ArrayViewVariant.
ConstArrayViewVariant make_device_view_variant(const Array& array);

/// @brief Return true if ArrayView<typename, int>::rank() is any of Ranks...
/// @brief Return true if View::rank() is any of Ranks...
template <typename View, int... Ranks>
constexpr bool RankIs() {
constexpr bool is_rank() {
return ((std::decay_t<View>::rank() == Ranks) || ...);
}

/// @brief Return true if View::non_const_value_type is any of Values...
template <typename View, typename... Values>
constexpr bool ValueIs() {
using Value = typename std::decay_t<View>::non_const_value_type;
return ((std::is_same_v<Value, Values>) || ...);
/// @brief Return true if View::value_type is any of ValuesTypes...
template <typename View, typename... ValueTypes>
constexpr bool is_value_type() {
using ValueType = typename std::decay_t<View>::value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}

/// @brief Return true if View::non_const_value_type is any of ValuesTypes...
template <typename View, typename... ValueTypes>
constexpr bool is_non_const_value_type() {
using ValueType = typename std::decay_t<View>::non_const_value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}

} // namespace array
Expand Down
57 changes: 32 additions & 25 deletions src/tests/array/test_array_view_variant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ CASE("test variant assignment") {
auto view3 = make_view_variant(array3);
auto view4 = make_view_variant(arrayRef);

auto hostView1 = make_host_view_variant(array1);
auto hostView2 = make_host_view_variant(array2);
auto hostView3 = make_host_view_variant(array3);
auto hostView4 = make_host_view_variant(arrayRef);
const auto hostView1 = make_host_view_variant(array1);
const auto hostView2 = make_host_view_variant(array2);
const auto hostView3 = make_host_view_variant(array3);
const auto hostView4 = make_host_view_variant(arrayRef);

auto deviceView1 = make_device_view_variant(array1);
auto deviceView2 = make_device_view_variant(array2);
Expand All @@ -46,33 +46,37 @@ CASE("test variant assignment") {
const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) {
std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 1);
EXPECT((std::is_same_v<typename View::value_type, float>));
using View = decltype(view);
EXPECT((is_rank<View, 1>()));
EXPECT((is_value_type<View, float>()));
EXPECT((is_non_const_value_type<View, float>()));
},
var1);

std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 2);
EXPECT((std::is_same_v<typename View::value_type, double>));
using View = decltype(view);
EXPECT((is_rank<View, 2>()));
EXPECT((is_value_type<View, double>()));
EXPECT((is_non_const_value_type<View, double>()));
},
var2);

std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 3);
EXPECT((std::is_same_v<typename View::value_type, int>));
using View = decltype(view);
EXPECT((is_rank<View, 3>()));
EXPECT((is_value_type<View, int>()));
EXPECT((is_non_const_value_type<View, int>()));
},
var3);

std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 1);
EXPECT((std::is_same_v<typename View::value_type, const float>));
using View = decltype(view);
EXPECT((is_rank<View, 1>()));
EXPECT((is_value_type<View, const float>()));
EXPECT((is_non_const_value_type<View, float>()));
},
var4);
};
Expand All @@ -90,21 +94,24 @@ CASE("test std::visit") {
make_view<int, 2>(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

const auto testRank1 = [](auto&& view) {


using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 1);
using Value = typename View::value_type;
EXPECT((std::is_same_v<Value, int>));

EXPECT((is_rank<View, 1>()));
EXPECT((is_value_type<View, int>()));
for (auto i = size_t{0}; i < view.size(); ++i) {
EXPECT_EQ(view(i), static_cast<Value>(i));
}
};

const auto testRank2 = [](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 2);
using Value = typename View::value_type;
EXPECT((std::is_same_v<Value, int>));

EXPECT((is_rank<View, 2>()));
EXPECT((is_value_type<View, int>()));

auto testValue = int{0};
for (auto i = idx_t{0}; i < view.shape(0); ++i) {
Expand All @@ -122,11 +129,11 @@ CASE("test std::visit") {
auto rank2Tested = false;

const auto visitor = [&](auto&& view) {
if constexpr (RankIs<decltype(view), 1>()) {
if constexpr (is_rank<decltype(view), 1>()) {
rank1Tested = true;
return testRank1(view);
}
if constexpr (RankIs<decltype(view), 2>()) {
if constexpr (is_rank<decltype(view), 2>()) {
rank2Tested = true;
return testRank2(view);
}
Expand All @@ -146,15 +153,15 @@ CASE("test std::visit") {
auto rank1Tested = false;
auto rank2Tested = false;
const auto visitor = eckit::Overloaded{
[&](auto&& view) -> std::enable_if_t<RankIs<decltype(view), 1>()> {
[&](auto&& view) -> std::enable_if_t<is_rank<decltype(view), 1>()> {
testRank1(view);
rank1Tested = true;
},
[&](auto&& view) -> std::enable_if_t<RankIs<decltype(view), 2>()> {
[&](auto&& view) -> std::enable_if_t<is_rank<decltype(view), 2>()> {
testRank2(view);
rank2Tested = true;
},
[](auto&& view) -> std::enable_if_t<!RankIs<decltype(view), 1, 2>()> {
[](auto&& view) -> std::enable_if_t<!is_rank<decltype(view), 1, 2>()> {
// Test should not reach here.
EXPECT(false);
}};
Expand Down

0 comments on commit d7743ee

Please sign in to comment.