Skip to content

Commit

Permalink
add device support to zip_iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Jul 8, 2024
1 parent caa373d commit 76dd812
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 88 deletions.
363 changes: 310 additions & 53 deletions core/base/iterator_factory.hpp

Large diffs are not rendered by default.

19 changes: 11 additions & 8 deletions core/test/base/iterator_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ TYPED_TEST(ZipIterator, IteratorReferenceOperatorSmaller2)

TYPED_TEST(ZipIterator, IncreasingIterator)
{
using gko::get;
using index_type = typename TestFixture::index_type;
using value_type = typename TestFixture::value_type;
std::vector<index_type> vec1{this->reversed_index};
Expand All @@ -182,8 +183,8 @@ TYPED_TEST(ZipIterator, IncreasingIterator)
ASSERT_TRUE(increment_pre_2 == increment_post_2);
ASSERT_TRUE(begin == increment_post_test++);
ASSERT_TRUE(begin + 1 == ++increment_pre_test);
ASSERT_TRUE(std::get<0>(*plus_2) == vec1[2]);
ASSERT_TRUE(std::get<1>(*plus_2) == vec2[2]);
ASSERT_TRUE(get<0>(*plus_2) == vec1[2]);
ASSERT_TRUE(get<1>(*plus_2) == vec2[2]);
// check other comparison operators and difference
std::vector<gko::detail::zip_iterator<index_type*, value_type*>> its{
begin,
Expand Down Expand Up @@ -257,6 +258,7 @@ TYPED_TEST(ZipIterator, IncompatibleIteratorDeathTest)

TYPED_TEST(ZipIterator, DecreasingIterator)
{
using gko::get;
using index_type = typename TestFixture::index_type;
using value_type = typename TestFixture::value_type;
std::vector<index_type> vec1{this->reversed_index};
Expand All @@ -280,13 +282,14 @@ TYPED_TEST(ZipIterator, DecreasingIterator)
ASSERT_TRUE(decrement_pre_2 == decrement_post_2);
ASSERT_TRUE(iter == decrement_post_test--);
ASSERT_TRUE(iter - 1 == --decrement_pre_test);
ASSERT_TRUE(std::get<0>(*minus_2) == vec1[3]);
ASSERT_TRUE(std::get<1>(*minus_2) == vec2[3]);
ASSERT_TRUE(get<0>(*minus_2) == vec1[3]);
ASSERT_TRUE(get<1>(*minus_2) == vec2[3]);
}


TYPED_TEST(ZipIterator, CorrectDereferencing)
{
using gko::get;
using index_type_it = typename TestFixture::index_type;
using value_type_it = typename TestFixture::value_type;
std::vector<index_type_it> vec1{this->reversed_index};
Expand All @@ -299,10 +302,10 @@ TYPED_TEST(ZipIterator, CorrectDereferencing)
auto to_test_ref = *(begin + element_to_test);
value_type to_test_pair = to_test_ref; // Testing implicit conversion

ASSERT_TRUE(std::get<0>(to_test_pair) == vec1[element_to_test]);
ASSERT_TRUE(std::get<0>(to_test_pair) == std::get<0>(to_test_ref));
ASSERT_TRUE(std::get<1>(to_test_pair) == vec2[element_to_test]);
ASSERT_TRUE(std::get<1>(to_test_pair) == std::get<1>(to_test_ref));
ASSERT_TRUE(get<0>(to_test_pair) == vec1[element_to_test]);
ASSERT_TRUE(get<0>(to_test_pair) == get<0>(to_test_ref));
ASSERT_TRUE(get<1>(to_test_pair) == vec2[element_to_test]);
ASSERT_TRUE(get<1>(to_test_pair) == get<1>(to_test_ref));
}


Expand Down
7 changes: 3 additions & 4 deletions omp/distributed/partition_helpers_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ void sort_by_range_start(
range_start_ends.get_data() + 1, [](const auto i) { return 2 * i; });
auto sort_it = detail::make_zip_iterator(start_it, end_it, part_ids_d);
// TODO: use TBB or parallel std with c++17
std::stable_sort(sort_it, sort_it + num_parts,
[](const auto& a, const auto& b) {
return std::get<0>(a) < std::get<0>(b);
});
std::stable_sort(
sort_it, sort_it + num_parts,
[](const auto& a, const auto& b) { return get<0>(a) < get<0>(b); });
}

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
Expand Down
5 changes: 2 additions & 3 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1155,9 +1155,8 @@ void sort_by_column_index(std::shared_ptr<const OmpExecutor> exec,
auto row_nnz = row_ptrs[i + 1] - start_row_idx;
auto it = detail::make_zip_iterator(col_idxs + start_row_idx,
values + start_row_idx);
std::sort(it, it + row_nnz, [](auto t1, auto t2) {
return std::get<0>(t1) < std::get<0>(t2);
});
std::sort(it, it + row_nnz,
[](auto t1, auto t2) { return get<0>(t1) < get<0>(t2); });
}
}

Expand Down
5 changes: 2 additions & 3 deletions omp/matrix/fbcsr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,8 @@ void sort_by_column_index_impl(
std::vector<IndexType> col_permute(nbnz_brow);
std::iota(col_permute.begin(), col_permute.end(), 0);
auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data());
std::sort(it, it + nbnz_brow, [](auto a, auto b) {
return std::get<0>(a) < std::get<0>(b);
});
std::sort(it, it + nbnz_brow,
[](auto a, auto b) { return get<0>(a) < get<0>(b); });

std::vector<ValueType> oldvalues(nbnz_brow * bs2);
std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin());
Expand Down
3 changes: 1 addition & 2 deletions omp/multigrid/pgm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ void sort_row_major(std::shared_ptr<const DefaultExecutor> exec, size_type nnz,
{
auto it = detail::make_zip_iterator(row_idxs, col_idxs, vals);
std::stable_sort(it, it + nnz, [](auto a, auto b) {
return std::tie(std::get<0>(a), std::get<1>(a)) <
std::tie(std::get<0>(b), std::get<1>(b));
return std::tie(get<0>(a), get<1>(a)) < std::tie(get<0>(b), get<1>(b));
});
}

Expand Down
13 changes: 6 additions & 7 deletions reference/distributed/partition_helpers_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ void sort_by_range_start(
auto end_it = detail::make_permute_iterator(
range_start_ends.get_data() + 1, [](const auto i) { return 2 * i; });
auto sort_it = detail::make_zip_iterator(start_it, end_it, part_ids_d);
std::stable_sort(sort_it, sort_it + num_parts,
[](const auto& a, const auto& b) {
return std::get<0>(a) < std::get<0>(b);
});
std::stable_sort(
sort_it, sort_it + num_parts,
[](const auto& a, const auto& b) { return get<0>(a) < get<0>(b); });
}

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
Expand All @@ -51,9 +50,9 @@ void check_consecutive_ranges(std::shared_ptr<const DefaultExecutor> exec,
auto range_it = detail::make_zip_iterator(start_it, end_it);

if (num_parts) {
result = std::all_of(
range_it, range_it + num_parts - 1,
[](const auto& r) { return std::get<0>(r) == std::get<1>(r); });
result =
std::all_of(range_it, range_it + num_parts - 1,
[](const auto& r) { return get<0>(r) == get<1>(r); });
} else {
result = true;
}
Expand Down
5 changes: 2 additions & 3 deletions reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1128,9 +1128,8 @@ void sort_by_column_index(std::shared_ptr<const ReferenceExecutor> exec,
auto row_nnz = row_ptrs[i + 1] - start_row_idx;
auto it = detail::make_zip_iterator(col_idxs + start_row_idx,
values + start_row_idx);
std::sort(it, it + row_nnz, [](auto t1, auto t2) {
return std::get<0>(t1) < std::get<0>(t2);
});
std::sort(it, it + row_nnz,
[](auto t1, auto t2) { return get<0>(t1) < get<0>(t2); });
}
}

Expand Down
5 changes: 2 additions & 3 deletions reference/matrix/fbcsr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,8 @@ void sort_by_column_index_impl(
std::vector<IndexType> col_permute(nbnz_brow);
std::iota(col_permute.begin(), col_permute.end(), 0);
auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data());
std::sort(it, it + nbnz_brow, [](auto a, auto b) {
return std::get<0>(a) < std::get<0>(b);
});
std::sort(it, it + nbnz_brow,
[](auto a, auto b) { return get<0>(a) < get<0>(b); });

std::vector<ValueType> oldvalues(nbnz_brow * bs2);
std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin());
Expand Down
3 changes: 1 addition & 2 deletions reference/multigrid/pgm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ void sort_row_major(std::shared_ptr<const DefaultExecutor> exec, size_type nnz,
{
auto it = detail::make_zip_iterator(row_idxs, col_idxs, vals);
std::stable_sort(it, it + nnz, [](auto a, auto b) {
return std::tie(std::get<0>(a), std::get<1>(a)) <
std::tie(std::get<0>(b), std::get<1>(b));
return std::tie(get<0>(a), get<1>(a)) < std::tie(get<0>(b), get<1>(b));
});
}

Expand Down
1 change: 1 addition & 0 deletions test/base/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
ginkgo_create_common_test(batch_multi_vector_kernels)
ginkgo_create_common_and_reference_test(device_matrix_data_kernels)
ginkgo_create_common_device_test(index_range)
ginkgo_create_common_device_test(iterator_factory)
ginkgo_create_common_device_test(kernel_launch_generic)
ginkgo_create_common_and_reference_test(executor)
ginkgo_create_common_and_reference_test(timer)
69 changes: 69 additions & 0 deletions test/base/iterator_factory.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "core/base/iterator_factory.hpp"

#include <memory>

#include <gtest/gtest.h>

#include <ginkgo/core/base/array.hpp>

#include "common/unified/base/kernel_launch.hpp"
#include "core/test/utils.hpp"
#include "test/utils/executor.hpp"


class IteratorFactory : public CommonTestFixture {
public:
IteratorFactory()
: key_array{exec, {6, 2, 3, 8, 1, 0, 2}},
value_array{exec, {9, 5, 7, 2, 4, 7, 2}},
expected_key_array{ref, {7, 1, 2, 2, 3, 6, 8}},
expected_value_array{ref, {7, 4, 2, 5, 7, 9, 2}}
{}

gko::array<int> key_array;
gko::array<int> value_array;
gko::array<int> expected_key_array;
gko::array<int> expected_value_array;
};


// nvcc doesn't like device lambdas declared in complex classes, move it out
void run_zip_iterator(std::shared_ptr<gko::EXEC_TYPE> exec,
gko::array<int>& key_array, gko::array<int>& value_array)
{
gko::kernels::EXEC_NAMESPACE::run_kernel(
exec,
[] GKO_KERNEL(auto i, auto keys, auto values, auto size) {
auto begin = gko::detail::make_zip_iterator(keys, values);
auto end = begin + size;
using std::swap;
for (auto it = begin; it != end; ++it) {
auto min_it = it;
for (auto it2 = it; it2 != end; ++it2) {
if (*it2 < *min_it) {
min_it = it2;
}
}
swap(*it, *min_it);
}
// check structured bindings
auto [key, value] = *begin;
static_assert(std::is_same<typeof(key), int>::value,
"incorrect type");
gko::get<0>(*begin) = value;
},
1, key_array, value_array, static_cast<int>(key_array.get_size()));
}


TEST_F(IteratorFactory, KernelRunsZipIterator)
{
run_zip_iterator(exec, key_array, value_array);

GKO_ASSERT_ARRAY_EQ(key_array, expected_key_array);
GKO_ASSERT_ARRAY_EQ(value_array, expected_value_array);
}

0 comments on commit 76dd812

Please sign in to comment.