diff --git a/core/base/iterator_factory.hpp b/core/base/iterator_factory.hpp index 3d224836b1a..938d705b04d 100644 --- a/core/base/iterator_factory.hpp +++ b/core/base/iterator_factory.hpp @@ -24,6 +24,234 @@ template class zip_iterator; +template +class zip_iterator_reference; + + +template +class device_tuple; + + +} // namespace detail +} // namespace gko + + +// structured binding specializations for device_tuple, zip_iterator_reference +namespace std { + + +template +struct tuple_size> + : integral_constant {}; + + +template +struct tuple_element> { + using type = typename tuple_element>::type; +}; + + +template +struct tuple_size> + : integral_constant {}; + + +template +struct tuple_element> { + using type = typename iterator_traits< + typename tuple_element>::type>::reference; +}; + + +} // namespace std + + +namespace gko { + + +/** std::get reimplementation for device_tuple. */ +template +constexpr typename std::tuple_element>::type& +get(detail::device_tuple& tuple); + + +/** std::get reimplementation for const device_tuple. */ +template +constexpr const typename std::tuple_element>::type& +get(const detail::device_tuple& tuple); + + +namespace detail { + + +/** simplified constexpr std::tuple reimplementation for use in device code. */ +template +class device_tuple { +public: + /** Constructs a device tuple from its elements. */ + constexpr explicit device_tuple(T value, Ts... others) + : value_{value}, other_{others...} + {} + + device_tuple() = default; + + /** + * Copy-assigns a tuple. + * This is necessary to make tuples of references work, which normally cause + * the impliciy copy-assignment operator to be deleted. + */ + constexpr device_tuple& operator=(const device_tuple& other) + { + value_ = other.value_; + other_ = other.other_; + return *this; + } + + /** @return the index-th element in the tuple. */ + template + constexpr typename std::tuple_element::type& get() + { + if constexpr (index == 0) { + return value_; + } else { + return other_.template get(); + } + } + + /** @return the index-th element in the const tuple. */ + template + constexpr const typename std::tuple_element::type& + get() const + { + if constexpr (index == 0) { + return value_; + } else { + return other_.template get(); + } + } + + // comparison operators + constexpr friend bool operator<(const device_tuple& lhs, + const device_tuple& rhs) + { + return lhs.value_ < rhs.value_ || + (lhs.value_ == rhs.value_ && lhs.other_ < rhs.other_); + } + + constexpr friend bool operator>(const device_tuple& lhs, + const device_tuple& rhs) + { + return rhs < lhs; + } + + constexpr friend bool operator>=(const device_tuple& lhs, + const device_tuple& rhs) + { + return !(lhs < rhs); + } + + constexpr friend bool operator<=(const device_tuple& lhs, + const device_tuple& rhs) + { + return !(lhs > rhs); + } + + constexpr friend bool operator==(const device_tuple& lhs, + const device_tuple& rhs) + { + return lhs.value_ == rhs.value_ && lhs.other_ == rhs.other_; + } + + constexpr friend bool operator!=(const device_tuple& lhs, + const device_tuple& rhs) + { + return !(lhs == rhs); + } + +private: + T value_; + device_tuple other_; +}; + + +template +class device_tuple { +public: + /** Constructs a device tuple from its elements. */ + constexpr explicit device_tuple(T value) : value_{value} {} + + device_tuple() = default; + + /** + * Copy-assigns a tuple. + * This is necessary to make tuples of references work, which normally cause + * the impliciy copy-assignment operator to be deleted. + */ + constexpr device_tuple& operator=(const device_tuple& other) + { + value_ = other.value_; + return *this; + } + + /** @return the index-th element in the tuple. */ + template + constexpr T& get() + { + static_assert(index == 0, "invalid index"); + return value_; + } + + /** @return the index-th element in the const tuple. */ + template + constexpr const T& get() const + { + static_assert(index == 0, "invalid index"); + return value_; + } + + // comparison operators + constexpr friend bool operator<(const device_tuple& lhs, + const device_tuple& rhs) + { + return lhs.value_ < rhs.value_; + } + + constexpr friend bool operator>(const device_tuple& lhs, + const device_tuple& rhs) + { + return rhs < lhs; + } + + constexpr friend bool operator>=(const device_tuple& lhs, + const device_tuple& rhs) + { + return !(lhs < rhs); + } + + constexpr friend bool operator<=(const device_tuple& lhs, + const device_tuple& rhs) + { + return !(lhs > rhs); + } + + constexpr friend bool operator==(const device_tuple& lhs, + const device_tuple& rhs) + { + return lhs.value_ == rhs.value_; + } + + constexpr friend bool operator!=(const device_tuple& lhs, + const device_tuple& rhs) + { + return !(lhs == rhs); + } + +private: + T value_; +}; + + /** * A reference-like type pointing to a tuple of elements originating from a * tuple of iterators. A few caveats related to its use: @@ -45,45 +273,51 @@ class zip_iterator; */ template class zip_iterator_reference - : public std::tuple< + : public device_tuple< typename std::iterator_traits::reference...> { using ref_tuple_type = - std::tuple::reference...>; + device_tuple::reference...>; using value_type = - std::tuple::value_type...>; + device_tuple::value_type...>; using index_sequence = std::index_sequence_for; friend class zip_iterator; template - value_type cast_impl(std::index_sequence) const + constexpr value_type cast_impl(std::index_sequence) const { // gcc 5 throws error as using uninitialized array // std::tuple t = { 1, '2' }; is not allowed. // converting to 'std::tuple<...>' from initializer list would use // explicit constructor - return value_type(std::get(*this)...); + return value_type(get(*this)...); } template - void assign_impl(std::index_sequence, const value_type& other) + constexpr void assign_impl(std::index_sequence, + const value_type& other) { (void)std::initializer_list{ - (std::get(*this) = std::get(other), 0)...}; + (get(*this) = get(other), 0)...}; } - zip_iterator_reference(Iterators... it) : ref_tuple_type{*it...} {} + constexpr explicit zip_iterator_reference(Iterators... it) + : ref_tuple_type{*it...} + {} public: - operator value_type() const { return cast_impl(index_sequence{}); } + constexpr operator value_type() const + { + return cast_impl(index_sequence{}); + } - zip_iterator_reference& operator=(const value_type& other) + constexpr zip_iterator_reference& operator=(const value_type& other) { assign_impl(index_sequence{}, other); return *this; } - value_type copy() const { return *this; } + constexpr value_type copy() const { return *this; } }; @@ -123,153 +357,156 @@ class zip_iterator { public: using difference_type = std::ptrdiff_t; using value_type = - std::tuple::value_type...>; + device_tuple::value_type...>; using pointer = value_type*; using reference = zip_iterator_reference; using iterator_category = std::random_access_iterator_tag; using index_sequence = std::index_sequence_for; - explicit zip_iterator() = default; + constexpr zip_iterator() = default; - explicit zip_iterator(Iterators... its) : iterators_{its...} {} + constexpr explicit zip_iterator(Iterators... its) : iterators_{its...} {} - zip_iterator& operator+=(difference_type i) + constexpr zip_iterator& operator+=(difference_type i) { forall([i](auto& it) { it += i; }); return *this; } - zip_iterator& operator-=(difference_type i) + constexpr zip_iterator& operator-=(difference_type i) { forall([i](auto& it) { it -= i; }); return *this; } - zip_iterator& operator++() + constexpr zip_iterator& operator++() { forall([](auto& it) { it++; }); return *this; } - zip_iterator operator++(int) + constexpr zip_iterator operator++(int) { auto tmp = *this; ++(*this); return tmp; } - zip_iterator& operator--() + constexpr zip_iterator& operator--() { forall([](auto& it) { it--; }); return *this; } - zip_iterator operator--(int) + constexpr zip_iterator operator--(int) { auto tmp = *this; --(*this); return tmp; } - zip_iterator operator+(difference_type i) const + constexpr zip_iterator operator+(difference_type i) const { auto tmp = *this; tmp += i; return tmp; } - friend zip_iterator operator+(difference_type i, const zip_iterator& iter) + constexpr friend zip_iterator operator+(difference_type i, + const zip_iterator& iter) { return iter + i; } - zip_iterator operator-(difference_type i) const + constexpr zip_iterator operator-(difference_type i) const { auto tmp = *this; tmp -= i; return tmp; } - difference_type operator-(const zip_iterator& other) const + constexpr difference_type operator-(const zip_iterator& other) const { return forall_check_consistent( other, [](const auto& a, const auto& b) { return a - b; }); } - reference operator*() const + constexpr reference operator*() const { return deref_impl(std::index_sequence_for{}); } - reference operator[](difference_type i) const { return *(*this + i); } + constexpr reference operator[](difference_type i) const + { + return *(*this + i); + } - bool operator==(const zip_iterator& other) const + constexpr bool operator==(const zip_iterator& other) const { return forall_check_consistent( other, [](const auto& a, const auto& b) { return a == b; }); } - bool operator!=(const zip_iterator& other) const + constexpr bool operator!=(const zip_iterator& other) const { return !(*this == other); } - bool operator<(const zip_iterator& other) const + constexpr bool operator<(const zip_iterator& other) const { return forall_check_consistent( other, [](const auto& a, const auto& b) { return a < b; }); } - bool operator<=(const zip_iterator& other) const + constexpr bool operator<=(const zip_iterator& other) const { return forall_check_consistent( other, [](const auto& a, const auto& b) { return a <= b; }); } - bool operator>(const zip_iterator& other) const + constexpr bool operator>(const zip_iterator& other) const { return !(*this <= other); } - bool operator>=(const zip_iterator& other) const + constexpr bool operator>=(const zip_iterator& other) const { return !(*this < other); } private: template - reference deref_impl(std::index_sequence) const + constexpr reference deref_impl(std::index_sequence) const { - return reference{std::get(iterators_)...}; + return reference{get(iterators_)...}; } template - void forall(Functor fn) + constexpr void forall(Functor fn) { forall_impl(fn, index_sequence{}); } template - void forall_impl(Functor fn, std::index_sequence) + constexpr void forall_impl(Functor fn, std::index_sequence) { - (void)std::initializer_list{ - (fn(std::get(iterators_)), 0)...}; + (void)std::initializer_list{(fn(get(iterators_)), 0)...}; } template - void forall_impl(const zip_iterator& other, Functor fn, - std::index_sequence) const + constexpr void forall_impl(const zip_iterator& other, Functor fn, + std::index_sequence) const { (void)std::initializer_list{ - (fn(std::get(iterators_), std::get(other.iterators_)), - 0)...}; + (fn(get(iterators_), get(other.iterators_)), 0)...}; } template - auto forall_check_consistent(const zip_iterator& other, Functor fn) const + constexpr auto forall_check_consistent(const zip_iterator& other, + Functor fn) const { - auto it = std::get<0>(iterators_); - auto other_it = std::get<0>(other.iterators_); + auto it = get<0>(iterators_); + auto other_it = get<0>(other.iterators_); auto result = fn(it, other_it); forall_impl( other, [&](auto a, auto b) { assert(it - other_it == a - b); }, @@ -277,12 +514,13 @@ class zip_iterator { return result; } - std::tuple iterators_; + device_tuple iterators_; }; template -zip_iterator...> make_zip_iterator(Iterators&&... it) +constexpr zip_iterator...> make_zip_iterator( + Iterators&&... it) { return zip_iterator...>{ std::forward(it)...}; @@ -305,8 +543,8 @@ zip_iterator...> make_zip_iterator(Iterators&&... it) * @tparam Iterators the iterator types inside the corresponding zip_iterator */ template -void swap(zip_iterator_reference a, - zip_iterator_reference b) +constexpr void swap(zip_iterator_reference a, + zip_iterator_reference b) { auto tmp = a.copy(); a = b; @@ -318,8 +556,8 @@ void swap(zip_iterator_reference a, * @copydoc swap(zip_iterator_reference, zip_iterator_reference) */ template -void swap(typename zip_iterator::value_type& a, - zip_iterator_reference b) +constexpr void swap(typename zip_iterator::value_type& a, + zip_iterator_reference b) { auto tmp = a; a = b; @@ -331,8 +569,8 @@ void swap(typename zip_iterator::value_type& a, * @copydoc swap(zip_iterator_reference, zip_iterator_reference) */ template -void swap(zip_iterator_reference a, - typename zip_iterator::value_type& b) +constexpr void swap(zip_iterator_reference a, + typename zip_iterator::value_type& b) { auto tmp = a.copy(); a = b; @@ -468,6 +706,25 @@ permute_iterator make_permute_iterator( } // namespace detail + + +template +constexpr typename std::tuple_element>::type& +get(detail::device_tuple& tuple) +{ + return tuple.template get(); +} + + +template +constexpr const typename std::tuple_element>::type& +get(const detail::device_tuple& tuple) +{ + return tuple.template get(); +} + + } // namespace gko diff --git a/core/test/base/iterator_factory.cpp b/core/test/base/iterator_factory.cpp index 42ddff343c0..c4dc30bf219 100644 --- a/core/test/base/iterator_factory.cpp +++ b/core/test/base/iterator_factory.cpp @@ -156,6 +156,7 @@ TYPED_TEST(ZipIterator, IteratorReferenceOperatorSmaller2) TYPED_TEST(ZipIterator, IncreasingIterator) { + using gko::get; using index_type = typename TestFixture::index_type; using value_type = typename TestFixture::value_type; std::vector vec1{this->reversed_index}; @@ -182,8 +183,8 @@ TYPED_TEST(ZipIterator, IncreasingIterator) ASSERT_TRUE(increment_pre_2 == increment_post_2); ASSERT_TRUE(begin == increment_post_test++); ASSERT_TRUE(begin + 1 == ++increment_pre_test); - ASSERT_TRUE(std::get<0>(*plus_2) == vec1[2]); - ASSERT_TRUE(std::get<1>(*plus_2) == vec2[2]); + ASSERT_TRUE(get<0>(*plus_2) == vec1[2]); + ASSERT_TRUE(get<1>(*plus_2) == vec2[2]); // check other comparison operators and difference std::vector> its{ begin, @@ -257,6 +258,7 @@ TYPED_TEST(ZipIterator, IncompatibleIteratorDeathTest) TYPED_TEST(ZipIterator, DecreasingIterator) { + using gko::get; using index_type = typename TestFixture::index_type; using value_type = typename TestFixture::value_type; std::vector vec1{this->reversed_index}; @@ -280,13 +282,14 @@ TYPED_TEST(ZipIterator, DecreasingIterator) ASSERT_TRUE(decrement_pre_2 == decrement_post_2); ASSERT_TRUE(iter == decrement_post_test--); ASSERT_TRUE(iter - 1 == --decrement_pre_test); - ASSERT_TRUE(std::get<0>(*minus_2) == vec1[3]); - ASSERT_TRUE(std::get<1>(*minus_2) == vec2[3]); + ASSERT_TRUE(get<0>(*minus_2) == vec1[3]); + ASSERT_TRUE(get<1>(*minus_2) == vec2[3]); } TYPED_TEST(ZipIterator, CorrectDereferencing) { + using gko::get; using index_type_it = typename TestFixture::index_type; using value_type_it = typename TestFixture::value_type; std::vector vec1{this->reversed_index}; @@ -299,10 +302,10 @@ TYPED_TEST(ZipIterator, CorrectDereferencing) auto to_test_ref = *(begin + element_to_test); value_type to_test_pair = to_test_ref; // Testing implicit conversion - ASSERT_TRUE(std::get<0>(to_test_pair) == vec1[element_to_test]); - ASSERT_TRUE(std::get<0>(to_test_pair) == std::get<0>(to_test_ref)); - ASSERT_TRUE(std::get<1>(to_test_pair) == vec2[element_to_test]); - ASSERT_TRUE(std::get<1>(to_test_pair) == std::get<1>(to_test_ref)); + ASSERT_TRUE(get<0>(to_test_pair) == vec1[element_to_test]); + ASSERT_TRUE(get<0>(to_test_pair) == get<0>(to_test_ref)); + ASSERT_TRUE(get<1>(to_test_pair) == vec2[element_to_test]); + ASSERT_TRUE(get<1>(to_test_pair) == get<1>(to_test_ref)); } diff --git a/omp/distributed/partition_helpers_kernels.cpp b/omp/distributed/partition_helpers_kernels.cpp index ceae3e17679..a3dfa8fdef4 100644 --- a/omp/distributed/partition_helpers_kernels.cpp +++ b/omp/distributed/partition_helpers_kernels.cpp @@ -27,10 +27,9 @@ void sort_by_range_start( range_start_ends.get_data() + 1, [](const auto i) { return 2 * i; }); auto sort_it = detail::make_zip_iterator(start_it, end_it, part_ids_d); // TODO: use TBB or parallel std with c++17 - std::stable_sort(sort_it, sort_it + num_parts, - [](const auto& a, const auto& b) { - return std::get<0>(a) < std::get<0>(b); - }); + std::stable_sort( + sort_it, sort_it + num_parts, + [](const auto& a, const auto& b) { return get<0>(a) < get<0>(b); }); } GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE( diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 09d1465896b..8e47caef520 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -1155,9 +1155,8 @@ void sort_by_column_index(std::shared_ptr exec, auto row_nnz = row_ptrs[i + 1] - start_row_idx; auto it = detail::make_zip_iterator(col_idxs + start_row_idx, values + start_row_idx); - std::sort(it, it + row_nnz, [](auto t1, auto t2) { - return std::get<0>(t1) < std::get<0>(t2); - }); + std::sort(it, it + row_nnz, + [](auto t1, auto t2) { return get<0>(t1) < get<0>(t2); }); } } diff --git a/omp/matrix/fbcsr_kernels.cpp b/omp/matrix/fbcsr_kernels.cpp index db60d85db79..a6342034a56 100644 --- a/omp/matrix/fbcsr_kernels.cpp +++ b/omp/matrix/fbcsr_kernels.cpp @@ -398,9 +398,8 @@ void sort_by_column_index_impl( std::vector col_permute(nbnz_brow); std::iota(col_permute.begin(), col_permute.end(), 0); auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data()); - std::sort(it, it + nbnz_brow, [](auto a, auto b) { - return std::get<0>(a) < std::get<0>(b); - }); + std::sort(it, it + nbnz_brow, + [](auto a, auto b) { return get<0>(a) < get<0>(b); }); std::vector oldvalues(nbnz_brow * bs2); std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin()); diff --git a/omp/multigrid/pgm_kernels.cpp b/omp/multigrid/pgm_kernels.cpp index 9d2aa047cc4..4c824a0140b 100644 --- a/omp/multigrid/pgm_kernels.cpp +++ b/omp/multigrid/pgm_kernels.cpp @@ -43,8 +43,7 @@ void sort_row_major(std::shared_ptr exec, size_type nnz, { auto it = detail::make_zip_iterator(row_idxs, col_idxs, vals); std::stable_sort(it, it + nnz, [](auto a, auto b) { - return std::tie(std::get<0>(a), std::get<1>(a)) < - std::tie(std::get<0>(b), std::get<1>(b)); + return std::tie(get<0>(a), get<1>(a)) < std::tie(get<0>(b), get<1>(b)); }); } diff --git a/reference/distributed/partition_helpers_kernels.cpp b/reference/distributed/partition_helpers_kernels.cpp index b57daab2eaa..0307974f278 100644 --- a/reference/distributed/partition_helpers_kernels.cpp +++ b/reference/distributed/partition_helpers_kernels.cpp @@ -26,10 +26,9 @@ void sort_by_range_start( auto end_it = detail::make_permute_iterator( range_start_ends.get_data() + 1, [](const auto i) { return 2 * i; }); auto sort_it = detail::make_zip_iterator(start_it, end_it, part_ids_d); - std::stable_sort(sort_it, sort_it + num_parts, - [](const auto& a, const auto& b) { - return std::get<0>(a) < std::get<0>(b); - }); + std::stable_sort( + sort_it, sort_it + num_parts, + [](const auto& a, const auto& b) { return get<0>(a) < get<0>(b); }); } GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE( @@ -51,9 +50,9 @@ void check_consecutive_ranges(std::shared_ptr exec, auto range_it = detail::make_zip_iterator(start_it, end_it); if (num_parts) { - result = std::all_of( - range_it, range_it + num_parts - 1, - [](const auto& r) { return std::get<0>(r) == std::get<1>(r); }); + result = + std::all_of(range_it, range_it + num_parts - 1, + [](const auto& r) { return get<0>(r) == get<1>(r); }); } else { result = true; } diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index f7e2fab4411..be97da442a1 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -1128,9 +1128,8 @@ void sort_by_column_index(std::shared_ptr exec, auto row_nnz = row_ptrs[i + 1] - start_row_idx; auto it = detail::make_zip_iterator(col_idxs + start_row_idx, values + start_row_idx); - std::sort(it, it + row_nnz, [](auto t1, auto t2) { - return std::get<0>(t1) < std::get<0>(t2); - }); + std::sort(it, it + row_nnz, + [](auto t1, auto t2) { return get<0>(t1) < get<0>(t2); }); } } diff --git a/reference/matrix/fbcsr_kernels.cpp b/reference/matrix/fbcsr_kernels.cpp index 9e60e380d9c..cdedc36ddc0 100644 --- a/reference/matrix/fbcsr_kernels.cpp +++ b/reference/matrix/fbcsr_kernels.cpp @@ -418,9 +418,8 @@ void sort_by_column_index_impl( std::vector col_permute(nbnz_brow); std::iota(col_permute.begin(), col_permute.end(), 0); auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data()); - std::sort(it, it + nbnz_brow, [](auto a, auto b) { - return std::get<0>(a) < std::get<0>(b); - }); + std::sort(it, it + nbnz_brow, + [](auto a, auto b) { return get<0>(a) < get<0>(b); }); std::vector oldvalues(nbnz_brow * bs2); std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin()); diff --git a/reference/multigrid/pgm_kernels.cpp b/reference/multigrid/pgm_kernels.cpp index 2a6e3252a9f..bff2a776c6b 100644 --- a/reference/multigrid/pgm_kernels.cpp +++ b/reference/multigrid/pgm_kernels.cpp @@ -270,8 +270,7 @@ void sort_row_major(std::shared_ptr exec, size_type nnz, { auto it = detail::make_zip_iterator(row_idxs, col_idxs, vals); std::stable_sort(it, it + nnz, [](auto a, auto b) { - return std::tie(std::get<0>(a), std::get<1>(a)) < - std::tie(std::get<0>(b), std::get<1>(b)); + return std::tie(get<0>(a), get<1>(a)) < std::tie(get<0>(b), get<1>(b)); }); } diff --git a/test/base/CMakeLists.txt b/test/base/CMakeLists.txt index d54996f212a..5f31c25db19 100644 --- a/test/base/CMakeLists.txt +++ b/test/base/CMakeLists.txt @@ -1,6 +1,7 @@ ginkgo_create_common_test(batch_multi_vector_kernels) ginkgo_create_common_and_reference_test(device_matrix_data_kernels) ginkgo_create_common_device_test(index_range) +ginkgo_create_common_device_test(iterator_factory) ginkgo_create_common_device_test(kernel_launch_generic) ginkgo_create_common_and_reference_test(executor) ginkgo_create_common_and_reference_test(timer) diff --git a/test/base/iterator_factory.cpp b/test/base/iterator_factory.cpp new file mode 100644 index 00000000000..5dc97646960 --- /dev/null +++ b/test/base/iterator_factory.cpp @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/base/iterator_factory.hpp" + +#include + +#include + +#include + +#include "common/unified/base/kernel_launch.hpp" +#include "core/test/utils.hpp" +#include "test/utils/executor.hpp" + + +class IteratorFactory : public CommonTestFixture { +public: + IteratorFactory() + : key_array{exec, {6, 2, 3, 8, 1, 0, 2}}, + value_array{exec, {9, 5, 7, 2, 4, 7, 2}}, + expected_key_array{ref, {7, 1, 2, 2, 3, 6, 8}}, + expected_value_array{ref, {7, 4, 2, 5, 7, 9, 2}} + {} + + gko::array key_array; + gko::array value_array; + gko::array expected_key_array; + gko::array expected_value_array; +}; + + +// nvcc doesn't like device lambdas declared in complex classes, move it out +void run_zip_iterator(std::shared_ptr exec, + gko::array& key_array, gko::array& value_array) +{ + gko::kernels::EXEC_NAMESPACE::run_kernel( + exec, + [] GKO_KERNEL(auto i, auto keys, auto values, auto size) { + auto begin = gko::detail::make_zip_iterator(keys, values); + auto end = begin + size; + using std::swap; + for (auto it = begin; it != end; ++it) { + auto min_it = it; + for (auto it2 = it; it2 != end; ++it2) { + if (*it2 < *min_it) { + min_it = it2; + } + } + swap(*it, *min_it); + } + // check structured bindings + auto [key, value] = *begin; + static_assert(std::is_same::value, + "incorrect type"); + gko::get<0>(*begin) = value; + }, + 1, key_array, value_array, static_cast(key_array.get_size())); +} + + +TEST_F(IteratorFactory, KernelRunsZipIterator) +{ + run_zip_iterator(exec, key_array, value_array); + + GKO_ASSERT_ARRAY_EQ(key_array, expected_key_array); + GKO_ASSERT_ARRAY_EQ(value_array, expected_value_array); +}