From cbf7818df1da51fead09bbbadef0620de28ee55d Mon Sep 17 00:00:00 2001 From: odlomax Date: Fri, 13 Sep 2024 16:41:27 +0100 Subject: [PATCH] Updated test. --- src/atlas/array/ArrayViewVariant.cc | 15 +- src/atlas/array/ArrayViewVariant.h | 68 ++++---- src/tests/array/CMakeLists.txt | 4 +- src/tests/array/test_array_view_variant.cc | 172 +++++++++++++++++++++ src/tests/array/test_arrayviewvariant.cc | 110 ------------- 5 files changed, 215 insertions(+), 154 deletions(-) create mode 100644 src/tests/array/test_array_view_variant.cc delete mode 100644 src/tests/array/test_arrayviewvariant.cc diff --git a/src/atlas/array/ArrayViewVariant.cc b/src/atlas/array/ArrayViewVariant.cc index fb0db8fa5..e131214ae 100644 --- a/src/atlas/array/ArrayViewVariant.cc +++ b/src/atlas/array/ArrayViewVariant.cc @@ -23,11 +23,14 @@ namespace { template ArrayViewVariant executeMakeView(ArrayType& array, const MakeView& makeView) { using View = std::variant_alternative_t; - using Value = typename View::value_type; - constexpr auto Rank = View::rank(); - - if (array.datatype() == DataType::kind() && array.rank() == Rank) { - return makeView(array, Value{}, std::integral_constant{}); + constexpr auto Const = std::is_const_v; + + if constexpr (std::is_const_v == Const) { + using Value = typename View::non_const_value_type; + constexpr auto Rank = View::rank(); + if (array.datatype() == DataType::kind() && array.rank() == Rank) { + return makeView(array, Value{}, std::integral_constant{}); + } } if constexpr (TypeIndex < std::variant_size_v - 1) { @@ -82,7 +85,7 @@ ArrayViewVariant make_host_view_variant(const Array& array) { return makeHostViewVariantImpl(array); } -ArrayViewVariant make_devive_view_variant(Array& array) { +ArrayViewVariant make_device_view_variant(Array& array) { return makeDeviceViewVariantImpl(array); } diff --git a/src/atlas/array/ArrayViewVariant.h b/src/atlas/array/ArrayViewVariant.h index f369636fa..f2d6d94f3 100644 --- a/src/atlas/array/ArrayViewVariant.h +++ b/src/atlas/array/ArrayViewVariant.h @@ -10,7 +10,6 @@ #include #include "atlas/array.h" -#include "eckit/utils/Overloaded.h" namespace atlas { namespace array { @@ -24,60 +23,57 @@ template struct Types {}; // Container struct for a list of integers. -template +template struct Ints {}; -// Supported ArrayView value types. -constexpr auto Values = Types{}; - -// Supported ArrayView ranks. -constexpr auto Ranks = Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>{}; - -// Helper struct to build an ArrayView variant from a list of value types and -// and a list of ranks. -template -struct VariantBuilder { - using type = std::variant; - - // Make a VariantBuilder struct with a fully populated Views... argument. - template - static constexpr auto make(Types, Ints) { - using NewBuilder = VariantBuilder..., - ArrayView...>; - if constexpr (sizeof...(Ts) > 0) { - return NewBuilder::make(Types{}, Ints{}); - } else { - return NewBuilder{}; - } - } +template +struct VariantHelper; + +// Recursively construct ArrayView std::variant from types Ts and Ranks Is. +template +struct VariantHelper, Types, Ints> { + using type = typename VariantHelper< + Types..., ArrayView...>, + Types, Ints>::type; +}; + +// End recursion. +template +struct VariantHelper, Types<>, Ints> { + using type = std::variant; }; -constexpr auto variantHelper = VariantBuilder<>::make(Values, Ranks); + +template +using Variant = typename VariantHelper, Values, Ranks>::type; } // namespace detail -/// @brief Variant containing all supported ArrayView alternatives. -using ArrayViewVariant = decltype(detail::variantHelper)::type; +/// @brief Supported ArrayView value types. +using Values = detail::Types; + +/// @brief Supported ArrayView ranks. +using Ranks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>; -/// @brief Use overloaded pattern as visitor. -using eckit::Overloaded; +/// @brief Variant containing all supported ArrayView alternatives. +using ArrayViewVariant = detail::Variant; /// @brief Create an ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_view_variant(array::Array& array); +ArrayViewVariant make_view_variant(Array& array); /// @brief Create a const ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_view_variant(const array::Array& array); +ArrayViewVariant make_view_variant(const Array& array); /// @brief Create a host ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_host_view_variant(array::Array& array); +ArrayViewVariant make_host_view_variant(Array& array); /// @brief Create a host const ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_host_view_variant(const array::Array& array); +ArrayViewVariant make_host_view_variant(const Array& array); /// @brief Create a device ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_devive_view_variant(array::Array& array); +ArrayViewVariant make_device_view_variant(Array& array); /// @brief Create a const device ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_device_view_variant(const array::Array& array); +ArrayViewVariant make_device_view_variant(const Array& array); } // namespace array } // namespace atlas diff --git a/src/tests/array/CMakeLists.txt b/src/tests/array/CMakeLists.txt index eda04a6d1..880e46ba6 100644 --- a/src/tests/array/CMakeLists.txt +++ b/src/tests/array/CMakeLists.txt @@ -81,8 +81,8 @@ atlas_add_hic_test( ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ) -ecbuild_add_test( TARGET atlas_test_arrayviewviewvariant - SOURCES test_arrayviewvariant.cc +ecbuild_add_test( TARGET atlas_test_array_view_variant + SOURCES test_array_view_variant.cc LIBS atlas ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ) diff --git a/src/tests/array/test_array_view_variant.cc b/src/tests/array/test_array_view_variant.cc new file mode 100644 index 000000000..6d171a0d8 --- /dev/null +++ b/src/tests/array/test_array_view_variant.cc @@ -0,0 +1,172 @@ +/* + * (C) Crown Copyright 2024 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include +#include + +#include "atlas/array.h" +#include "atlas/array/ArrayViewVariant.h" +#include "eckit/utils/Overloaded.h" +#include "tests/AtlasTestEnvironment.h" + +namespace atlas { +namespace test { + +using namespace array; + +CASE("test variant assignment") { + auto array1 = array::ArrayT(2); + auto array2 = array::ArrayT(2, 3); + auto array3 = array::ArrayT(2, 3, 4); + const auto& arrayRef = array1; + + array1.allocateDevice(); + array2.allocateDevice(); + array3.allocateDevice(); + + auto view1 = make_view_variant(array1); + auto view2 = make_view_variant(array2); + auto view3 = make_view_variant(array3); + auto view4 = make_view_variant(arrayRef); + + auto hostView1 = make_host_view_variant(array1); + auto hostView2 = make_host_view_variant(array2); + auto hostView3 = make_host_view_variant(array3); + auto hostView4 = make_host_view_variant(arrayRef); + + auto deviceView1 = make_device_view_variant(array1); + auto deviceView2 = make_device_view_variant(array2); + auto deviceView3 = make_device_view_variant(array3); + auto deviceView4 = make_device_view_variant(arrayRef); + + const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) { + std::visit( + [](auto&& view) { + using View = std::remove_reference_t; + EXPECT_EQ(View::rank(), 1); + EXPECT((std::is_same_v)); + }, + var1); + + std::visit( + [](auto&& view) { + using View = std::remove_reference_t; + EXPECT_EQ(View::rank(), 2); + EXPECT((std::is_same_v)); + }, + var2); + + std::visit( + [](auto&& view) { + using View = std::remove_reference_t; + EXPECT_EQ(View::rank(), 3); + EXPECT((std::is_same_v)); + }, + var3); + + std::visit( + [](auto&& view) { + using View = std::remove_reference_t; + EXPECT_EQ(View::rank(), 1); + EXPECT((std::is_same_v)); + }, + var4); + }; + + visitVariants(view1, view2, view3, view4); + visitVariants(hostView1, hostView2, hostView3, hostView4); + visitVariants(deviceView1, deviceView2, deviceView3, deviceView4); +} + +template +constexpr auto Rank = std::decay_t::rank(); + +CASE("test std::visit") { + auto array1 = ArrayT(10); + make_view(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + auto array2 = ArrayT(5, 2); + make_view(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + const auto testRank1 = [](auto&& view) { + using View = std::decay_t; + EXPECT_EQ(View::rank(), 1); + using Value = typename View::value_type; + EXPECT((std::is_same_v)); + + for (auto i = size_t{0}; i < view.size(); ++i) { + EXPECT_EQ(view(i), static_cast(i)); + } + }; + + const auto testRank2 = [](auto&& view) { + using View = std::decay_t; + EXPECT_EQ(View::rank(), 2); + using Value = typename View::value_type; + EXPECT((std::is_same_v)); + + auto testValue = int{0}; + for (auto i = idx_t{0}; i < view.shape(0); ++i) { + for (auto j = idx_t{0}; j < view.shape(1); ++j) { + EXPECT_EQ(view(i, j), static_cast(testValue++)); + } + } + }; + + const auto var1 = make_view_variant(array1); + const auto var2 = make_view_variant(array2); + + SECTION("demonstrate 'if constexpr' pattern") { + auto rank1Tested = false; + auto rank2Tested = false; + + const auto visitor = [&](auto&& view) { + if constexpr (Rank == 1) { + testRank1(view); + rank1Tested = true; + } + if constexpr (Rank == 2) { + testRank2(view); + rank2Tested = true; + } + }; + + std::visit(visitor, var1); + EXPECT(rank1Tested); + std::visit(visitor, var2); + EXPECT(rank2Tested); + } + + SECTION("demonstrate 'overloaded' pattern") { + // Note, SFINAE can be eliminated using concepts and explicit lambda + // templates in C++20. + auto rank1Tested = false; + auto rank2Tested = false; + const auto visitor = eckit::Overloaded{ + [&](auto&& view) -> std::enable_if_t == 1> { + testRank1(view); + rank1Tested = true; + }, + [&](auto&& view) -> std::enable_if_t == 2> { + testRank2(view); + rank2Tested = true; + }, + [](auto&& view) -> std::enable_if_t >= 3> { + // do nothing. + }}; + + std::visit(visitor, var1); + EXPECT(rank1Tested); + std::visit(visitor, var2); + EXPECT(rank2Tested); + } +} + +} // namespace test +} // namespace atlas + +int main(int argc, char** argv) { return atlas::test::run(argc, argv); } diff --git a/src/tests/array/test_arrayviewvariant.cc b/src/tests/array/test_arrayviewvariant.cc deleted file mode 100644 index 2ab24bd2c..000000000 --- a/src/tests/array/test_arrayviewvariant.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* - * (C) Crown Copyright 2024 Met Office - * - * This software is licensed under the terms of the Apache Licence Version 2.0 - * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. - */ - -#include -#include -#include - -#include "atlas/array.h" -#include "atlas/array/ArrayViewVariant.h" -#include "tests/AtlasTestEnvironment.h" - -namespace atlas { -namespace test { - -using namespace array; - -CASE("test visit") { - auto arr1 = array::ArrayT(2); - auto arr2 = array::ArrayT(2, 3); - auto arr3 = array::ArrayT(2, 3, 4); - - const auto var1 = make_view_variant(arr1); - const auto var2 = make_view_variant(arr2); - const auto var3 = make_view_variant(arr3); - - std::visit( - [](auto&& view) { - using View = std::remove_reference_t; - EXPECT_EQ(View::rank(), 1); - EXPECT((std::is_same_v)); - }, - var1); - - std::visit( - [](auto&& view) { - using View = std::remove_reference_t; - EXPECT_EQ(View::rank(), 2); - EXPECT((std::is_same_v)); - }, - var2); - - std::visit( - [](auto&& view) { - using View = std::remove_reference_t; - EXPECT_EQ(View::rank(), 3); - EXPECT((std::is_same_v)); - }, - var3); -} - -template -constexpr auto Rank = std::decay_t::rank(); - -CASE("test array view data") { - auto arr = ArrayT(10); - make_view(arr).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - - const auto& arrRef = arr; - const auto var = make_view_variant(arrRef); - - const auto visitor = Overloaded{ - [](auto&& view) -> std::enable_if_t == 1> { - using View = std::decay_t; - EXPECT_EQ(View::rank(), 1); - using Value = typename View::value_type; - EXPECT((std::is_same_v)); - - for (auto i = size_t{0}; i < view.size(); ++i) { - EXPECT_EQ(view(i), static_cast(i)); - } - }, - [](auto&& view) -> std::enable_if_t != 1> { - // do nothing. - }}; - - std::visit(visitor, var); -} - -CASE("test instantiation") { - auto arr = array::ArrayT(1); - const auto constArr = array::ArrayT(1); - - // SECTION("default variants") { - // auto var = make_view_variant(arr); - // auto constVar = make_view_variant(constArr); - - // using VarType = std::variant, ArrayView, - // ArrayView, ArrayView, - // ArrayView, ArrayView>; - - // using ConstVarType = - // std::variant, ArrayView, - // ArrayView, ArrayView, - // ArrayView, ArrayView>; - - // EXPECT((std::is_same_v)); - // EXPECT((std::is_same_v)); - // } -} - -} // namespace test -} // namespace atlas - -int main(int argc, char** argv) { return atlas::test::run(argc, argv); }