Skip to content

Commit

Permalink
Updated test.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Sep 13, 2024
1 parent c3638d3 commit cbf7818
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 154 deletions.
15 changes: 9 additions & 6 deletions src/atlas/array/ArrayViewVariant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ namespace {
template <size_t TypeIndex = 0, typename ArrayType, typename MakeView>
ArrayViewVariant executeMakeView(ArrayType& array, const MakeView& makeView) {
using View = std::variant_alternative_t<TypeIndex, ArrayViewVariant>;
using Value = typename View::value_type;
constexpr auto Rank = View::rank();

if (array.datatype() == DataType::kind<Value>() && array.rank() == Rank) {
return makeView(array, Value{}, std::integral_constant<idx_t, Rank>{});
constexpr auto Const = std::is_const_v<typename View::value_type>;

if constexpr (std::is_const_v<ArrayType> == Const) {
using Value = typename View::non_const_value_type;
constexpr auto Rank = View::rank();
if (array.datatype() == DataType::kind<Value>() && array.rank() == Rank) {
return makeView(array, Value{}, std::integral_constant<int, Rank>{});
}
}

if constexpr (TypeIndex < std::variant_size_v<ArrayViewVariant> - 1) {
Expand Down Expand Up @@ -82,7 +85,7 @@ ArrayViewVariant make_host_view_variant(const Array& array) {
return makeHostViewVariantImpl(array);
}

ArrayViewVariant make_devive_view_variant(Array& array) {
ArrayViewVariant make_device_view_variant(Array& array) {
return makeDeviceViewVariantImpl(array);
}

Expand Down
68 changes: 32 additions & 36 deletions src/atlas/array/ArrayViewVariant.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <variant>

#include "atlas/array.h"
#include "eckit/utils/Overloaded.h"

namespace atlas {
namespace array {
Expand All @@ -24,60 +23,57 @@ template <typename... Ts>
struct Types {};

// Container struct for a list of integers.
template <idx_t... Is>
template <int... Is>
struct Ints {};

// Supported ArrayView value types.
constexpr auto Values = Types<float, double, int, long, unsigned long>{};

// Supported ArrayView ranks.
constexpr auto Ranks = Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>{};

// Helper struct to build an ArrayView variant from a list of value types and
// and a list of ranks.
template <typename... Views>
struct VariantBuilder {
using type = std::variant<Views...>;

// Make a VariantBuilder struct with a fully populated Views... argument.
template <typename T, typename... Ts, idx_t... Is>
static constexpr auto make(Types<T, Ts...>, Ints<Is...>) {
using NewBuilder = VariantBuilder<Views..., ArrayView<T, Is>...,
ArrayView<const T, Is>...>;
if constexpr (sizeof...(Ts) > 0) {
return NewBuilder::make(Types<Ts...>{}, Ints<Is...>{});
} else {
return NewBuilder{};
}
}
template <typename...>
struct VariantHelper;

// Recursively construct ArrayView std::variant from types Ts and Ranks Is.
template <typename... ArrayViews, typename T, typename... Ts, int... Is>
struct VariantHelper<Types<ArrayViews...>, Types<T, Ts...>, Ints<Is...>> {
using type = typename VariantHelper<
Types<ArrayViews..., ArrayView<const T, Is>..., ArrayView<T, Is>...>,
Types<Ts...>, Ints<Is...>>::type;
};

// End recursion.
template <typename... ArrayViews, int... Is>
struct VariantHelper<Types<ArrayViews...>, Types<>, Ints<Is...>> {
using type = std::variant<ArrayViews...>;
};
constexpr auto variantHelper = VariantBuilder<>::make(Values, Ranks);

template <typename Values, typename Ranks>
using Variant = typename VariantHelper<Types<>, Values, Ranks>::type;

} // namespace detail

/// @brief Variant containing all supported ArrayView alternatives.
using ArrayViewVariant = decltype(detail::variantHelper)::type;
/// @brief Supported ArrayView value types.
using Values = detail::Types<float, double, int, long, unsigned long>;

/// @brief Supported ArrayView ranks.
using Ranks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>;

/// @brief Use overloaded pattern as visitor.
using eckit::Overloaded;
/// @brief Variant containing all supported ArrayView alternatives.
using ArrayViewVariant = detail::Variant<Values, Ranks>;

/// @brief Create an ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_view_variant(array::Array& array);
ArrayViewVariant make_view_variant(Array& array);

/// @brief Create a const ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_view_variant(const array::Array& array);
ArrayViewVariant make_view_variant(const Array& array);

/// @brief Create a host ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_host_view_variant(array::Array& array);
ArrayViewVariant make_host_view_variant(Array& array);

/// @brief Create a host const ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_host_view_variant(const array::Array& array);
ArrayViewVariant make_host_view_variant(const Array& array);

/// @brief Create a device ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_devive_view_variant(array::Array& array);
ArrayViewVariant make_device_view_variant(Array& array);

/// @brief Create a const device ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_device_view_variant(const array::Array& array);
ArrayViewVariant make_device_view_variant(const Array& array);

} // namespace array
} // namespace atlas
4 changes: 2 additions & 2 deletions src/tests/array/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ atlas_add_hic_test(
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_arrayviewviewvariant
SOURCES test_arrayviewvariant.cc
ecbuild_add_test( TARGET atlas_test_array_view_variant
SOURCES test_array_view_variant.cc
LIBS atlas
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)
172 changes: 172 additions & 0 deletions src/tests/array/test_array_view_variant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* (C) Crown Copyright 2024 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#include <type_traits>
#include <variant>

#include "atlas/array.h"
#include "atlas/array/ArrayViewVariant.h"
#include "eckit/utils/Overloaded.h"
#include "tests/AtlasTestEnvironment.h"

namespace atlas {
namespace test {

using namespace array;

CASE("test variant assignment") {
auto array1 = array::ArrayT<float>(2);
auto array2 = array::ArrayT<double>(2, 3);
auto array3 = array::ArrayT<int>(2, 3, 4);
const auto& arrayRef = array1;

array1.allocateDevice();
array2.allocateDevice();
array3.allocateDevice();

auto view1 = make_view_variant(array1);
auto view2 = make_view_variant(array2);
auto view3 = make_view_variant(array3);
auto view4 = make_view_variant(arrayRef);

auto hostView1 = make_host_view_variant(array1);
auto hostView2 = make_host_view_variant(array2);
auto hostView3 = make_host_view_variant(array3);
auto hostView4 = make_host_view_variant(arrayRef);

auto deviceView1 = make_device_view_variant(array1);
auto deviceView2 = make_device_view_variant(array2);
auto deviceView3 = make_device_view_variant(array3);
auto deviceView4 = make_device_view_variant(arrayRef);

const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) {
std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 1);
EXPECT((std::is_same_v<typename View::value_type, float>));
},
var1);

std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 2);
EXPECT((std::is_same_v<typename View::value_type, double>));
},
var2);

std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 3);
EXPECT((std::is_same_v<typename View::value_type, int>));
},
var3);

std::visit(
[](auto&& view) {
using View = std::remove_reference_t<decltype(view)>;
EXPECT_EQ(View::rank(), 1);
EXPECT((std::is_same_v<typename View::value_type, const float>));
},
var4);
};

visitVariants(view1, view2, view3, view4);
visitVariants(hostView1, hostView2, hostView3, hostView4);
visitVariants(deviceView1, deviceView2, deviceView3, deviceView4);
}

template <typename View>
constexpr auto Rank = std::decay_t<View>::rank();

CASE("test std::visit") {
auto array1 = ArrayT<int>(10);
make_view<int, 1>(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

auto array2 = ArrayT<int>(5, 2);
make_view<int, 2>(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

const auto testRank1 = [](auto&& view) {
using View = std::decay_t<decltype(view)>;
EXPECT_EQ(View::rank(), 1);
using Value = typename View::value_type;
EXPECT((std::is_same_v<Value, int>));

for (auto i = size_t{0}; i < view.size(); ++i) {
EXPECT_EQ(view(i), static_cast<Value>(i));
}
};

const auto testRank2 = [](auto&& view) {
using View = std::decay_t<decltype(view)>;
EXPECT_EQ(View::rank(), 2);
using Value = typename View::value_type;
EXPECT((std::is_same_v<Value, int>));

auto testValue = int{0};
for (auto i = idx_t{0}; i < view.shape(0); ++i) {
for (auto j = idx_t{0}; j < view.shape(1); ++j) {
EXPECT_EQ(view(i, j), static_cast<Value>(testValue++));
}
}
};

const auto var1 = make_view_variant(array1);
const auto var2 = make_view_variant(array2);

SECTION("demonstrate 'if constexpr' pattern") {
auto rank1Tested = false;
auto rank2Tested = false;

const auto visitor = [&](auto&& view) {
if constexpr (Rank<decltype(view)> == 1) {
testRank1(view);
rank1Tested = true;
}
if constexpr (Rank<decltype(view)> == 2) {
testRank2(view);
rank2Tested = true;
}
};

std::visit(visitor, var1);
EXPECT(rank1Tested);
std::visit(visitor, var2);
EXPECT(rank2Tested);
}

SECTION("demonstrate 'overloaded' pattern") {
// Note, SFINAE can be eliminated using concepts and explicit lambda
// templates in C++20.
auto rank1Tested = false;
auto rank2Tested = false;
const auto visitor = eckit::Overloaded{
[&](auto&& view) -> std::enable_if_t<Rank<decltype(view)> == 1> {
testRank1(view);
rank1Tested = true;
},
[&](auto&& view) -> std::enable_if_t<Rank<decltype(view)> == 2> {
testRank2(view);
rank2Tested = true;
},
[](auto&& view) -> std::enable_if_t<Rank<decltype(view)> >= 3> {
// do nothing.
}};

std::visit(visitor, var1);
EXPECT(rank1Tested);
std::visit(visitor, var2);
EXPECT(rank2Tested);
}
}

} // namespace test
} // namespace atlas

int main(int argc, char** argv) { return atlas::test::run(argc, argv); }
Loading

0 comments on commit cbf7818

Please sign in to comment.