diff --git a/common/cuda_hip/distributed/index_map_kernels.hpp.inc b/common/cuda_hip/distributed/index_map_kernels.hpp.inc new file mode 100644 index 00000000000..9d312cc43aa --- /dev/null +++ b/common/cuda_hip/distributed/index_map_kernels.hpp.inc @@ -0,0 +1,268 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +/** + * Maps (global index, range id) to (global index, local index). A struct is + * needed since `transform_output_iterator` seemingly doesn't support lambdas + */ +template +struct transform_output { + transform_output(const GlobalIndexType* range_bounds_, + const LocalIndexType* range_starting_indices_) + : range_bounds(range_bounds_), + range_starting_indices(range_starting_indices_) + {} + + template + __host__ __device__ thrust::tuple + operator()(const T& t) + { + auto gid = thrust::get<0>(t); + auto rid = thrust::get<1>(t); + return thrust::make_tuple(gid, map_to_local(gid, rid)); + } + + __host__ __device__ LocalIndexType map_to_local(const GlobalIndexType index, + const size_type range_id) + { + return static_cast(index - range_bounds[range_id]) + + range_starting_indices[range_id]; + }; + + const GlobalIndexType* range_bounds; + const LocalIndexType* range_starting_indices; +}; + + +template +array compute_range_ids( + std::shared_ptr exec, + const experimental::distributed::Partition* + part, + const array& idxs) +{ + const auto* range_bounds = part->get_range_bounds(); + const auto num_ranges = part->get_num_ranges(); + auto input_size = idxs.get_size(); + auto idxs_ptr = idxs.get_const_data(); + + auto policy = thrust_policy(exec); + + array range_ids{exec, input_size}; + thrust::upper_bound(policy, range_bounds + 1, range_bounds + num_ranges + 1, + idxs_ptr, idxs_ptr + input_size, range_ids.get_data()); + return range_ids; +} + + +template +void build_mapping( + std::shared_ptr exec, + const experimental::distributed::Partition* + part, + const array& recv_connections, + array& 
remote_part_ids, + array& remote_local_idxs, + array& remote_global_idxs, array& remote_sizes) +{ + auto part_ids = part->get_part_ids(); + auto num_parts = static_cast(part->get_num_parts()); + const auto* range_bounds = part->get_range_bounds(); + const auto* range_starting_indices = part->get_range_starting_indices(); + const auto num_ranges = part->get_num_ranges(); + auto input_size = recv_connections.get_size(); + + auto recv_connections_copy = recv_connections; + auto recv_connections_ptr = recv_connections_copy.get_data(); + + auto policy = thrust_policy(exec); + + // precompute the range id of each input element + auto range_ids = compute_range_ids(exec, part, recv_connections_copy); + auto it_range_ids = range_ids.get_data(); + + // map input to owning part-id + array full_remote_part_ids( + exec, input_size); + auto it_full_remote_part_ids = full_remote_part_ids.get_data(); + thrust::transform(policy, it_range_ids, it_range_ids + input_size, + it_full_remote_part_ids, + [part_ids] __host__ __device__(const size_type rid) { + return part_ids[rid]; + }); + + // sort by part-id and recv_connection + auto sort_it = thrust::make_zip_iterator( + thrust::make_tuple(it_full_remote_part_ids, recv_connections_ptr)); + thrust::sort_by_key(policy, sort_it, sort_it + input_size, it_range_ids); + + auto unique_end = thrust::unique_by_key(policy, sort_it, + sort_it + input_size, it_range_ids); + auto unique_range_id_end = unique_end.second; + auto unique_size = thrust::distance(it_range_ids, unique_range_id_end); + + remote_global_idxs.resize_and_reset(unique_size); + remote_local_idxs.resize_and_reset(unique_size); + + // store unique connections, also map global indices to local + { + auto copy_it = thrust::make_zip_iterator( + thrust::make_tuple(recv_connections_ptr, it_range_ids)); + thrust::copy_n(policy, copy_it, unique_size, + thrust::make_transform_output_iterator( + thrust::make_zip_iterator(thrust::make_tuple( + remote_global_idxs.get_data(), + 
remote_local_idxs.get_data())), + transform_output{ + range_bounds, range_starting_indices})); + } + + // compute number of connections per part-id + array full_remote_sizes(exec, + part->get_num_parts()); + auto recv_sizes_ptr = full_remote_sizes.get_data(); + thrust::fill_n(policy, recv_sizes_ptr, num_parts, 0); + thrust::for_each_n(policy, it_full_remote_part_ids, unique_size, + [recv_sizes_ptr] __device__(const size_type part) { + atomic_add(recv_sizes_ptr + part, 1); + }); + + auto is_neighbor = [] __host__ __device__(const size_type s) { + return s != 0; + }; + auto num_neighbors = + thrust::count_if(policy, recv_sizes_ptr, + recv_sizes_ptr + part->get_num_parts(), is_neighbor); + + remote_part_ids.resize_and_reset(num_neighbors); + + remote_sizes.resize_and_reset(num_neighbors); + { + auto counting_it = thrust::make_counting_iterator(0); + auto copy_it = thrust::make_zip_iterator( + thrust::make_tuple(recv_sizes_ptr, counting_it)); + thrust::copy_if( + policy, copy_it, copy_it + part->get_num_parts(), + thrust::make_zip_iterator(thrust::make_tuple( + remote_sizes.get_data(), remote_part_ids.get_data())), + [] __host__ __device__( + const thrust::tuple& t) { + return thrust::get<0>(t) > 0; + }); + } +} + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_BUILD_MAPPING); + + +template +void map_to_local( + std::shared_ptr exec, + const experimental::distributed::Partition* + partition, + const array& remote_target_ids, + device_segmented_array remote_global_idxs, + experimental::distributed::comm_index_type rank, + const array& global_ids, + experimental::distributed::index_space is, array& local_ids) +{ + auto part_ids = partition->get_part_ids(); + auto part_sizes = partition->get_part_sizes(); + auto num_parts = static_cast(partition->get_num_parts()); + const auto* range_bounds = partition->get_range_bounds(); + const auto* range_starting_indices = + partition->get_range_starting_indices(); + const auto num_ranges = 
partition->get_num_ranges(); + auto input_size = global_ids.get_size(); + auto global_ids_it = global_ids.get_const_data(); + + auto policy = thrust_policy(exec); + + local_ids.resize_and_reset(input_size); + auto local_ids_it = local_ids.get_data(); + + auto range_ids = compute_range_ids(exec, partition, global_ids); + auto range_ids_it = range_ids.get_const_data(); + + auto map_local = + [rank, part_ids, range_bounds, range_starting_indices] __device__( + const thrust::tuple& t) { + auto gid = thrust::get<0>(t); + auto rid = thrust::get<1>(t); + auto pid = part_ids[rid]; + return pid == rank + ? static_cast(gid - range_bounds[rid]) + + range_starting_indices[rid] + : invalid_index(); + }; + + auto remote_target_ids_ptr = remote_target_ids.get_const_data(); + auto num_target_ids = remote_target_ids.get_size(); + auto remote_global_idxs_ptr = remote_global_idxs.flat_begin; + auto offsets_ptr = remote_global_idxs.offsets_begin; + auto map_non_local = + [num_target_ids, remote_target_ids_ptr, part_ids, offsets_ptr, + remote_global_idxs_ptr] __device__(const thrust::tuple& t) { + auto gid = thrust::get<0>(t); + auto rid = thrust::get<1>(t); + auto pid = part_ids[rid]; + auto set_id = binary_search( + size_type{0}, num_target_ids, + [=](const auto i) { return remote_target_ids_ptr[i] >= pid; }); + + // Return an invalid index, if the part-id could not be found + if (set_id == num_target_ids) { + return invalid_index(); + } + + // need to check if the entry at position it is actually the + // current global-id, since the global-id might not be registered + // as connected to this rank + auto it = binary_search( + offsets_ptr[set_id], + offsets_ptr[set_id + 1] - offsets_ptr[set_id], + [=](const auto i) { return remote_global_idxs_ptr[i] >= gid; }); + return it != offsets_ptr[set_id + 1] && + remote_global_idxs_ptr[it] == gid + ? 
static_cast(it) + : invalid_index(); + }; + + auto map_combined = + [part_ids, rank, map_local, map_non_local, part_sizes] __device__( + const thrust::tuple& t) { + auto gid = thrust::get<0>(t); + auto rid = thrust::get<1>(t); + auto pid = part_ids[rid]; + + if (pid == rank) { + return map_local(t); + } else { + auto id = map_non_local(t); + return id == invalid_index() + ? id + : id + part_sizes[rank]; + } + }; + + auto transform_it = thrust::make_zip_iterator( + thrust::make_tuple(global_ids_it, range_ids_it)); + if (is == experimental::distributed::index_space::local) { + thrust::transform(policy, transform_it, transform_it + input_size, + local_ids_it, map_local); + } + if (is == experimental::distributed::index_space::non_local) { + thrust::transform(policy, transform_it, transform_it + input_size, + local_ids_it, map_non_local); + } + if (is == experimental::distributed::index_space::combined) { + thrust::transform(policy, transform_it, transform_it + input_size, + local_ids_it, map_combined); + } +} + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL); diff --git a/common/unified/CMakeLists.txt b/common/unified/CMakeLists.txt index d1387e8f8bf..00bc21df0c6 100644 --- a/common/unified/CMakeLists.txt +++ b/common/unified/CMakeLists.txt @@ -6,7 +6,6 @@ set(UNIFIED_SOURCES components/format_conversion_kernels.cpp components/precision_conversion_kernels.cpp components/reduce_array_kernels.cpp - distributed/index_map_kernels.cpp distributed/partition_helpers_kernels.cpp distributed/partition_kernels.cpp matrix/coo_kernels.cpp diff --git a/cuda/CMakeLists.txt b/cuda/CMakeLists.txt index ace6e61056d..bd214691a2e 100644 --- a/cuda/CMakeLists.txt +++ b/cuda/CMakeLists.txt @@ -20,6 +20,7 @@ target_sources(ginkgo_cuda base/timer.cpp base/version.cpp components/prefix_sum_kernels.cu + distributed/index_map_kernels.cu distributed/matrix_kernels.cu distributed/partition_helpers_kernels.cu distributed/partition_kernels.cu diff --git 
a/cuda/distributed/index_map_kernels.cu b/cuda/distributed/index_map_kernels.cu new file mode 100644 index 00000000000..a5d838e901f --- /dev/null +++ b/cuda/distributed/index_map_kernels.cu @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/distributed/index_map_kernels.hpp" + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + + +#include "cuda/base/thrust.cuh" +#include "cuda/components/atomic.cuh" +#include "cuda/components/searching.cuh" + + +namespace gko { +namespace kernels { +namespace cuda { +namespace index_map { + + +#include "common/cuda_hip/distributed/index_map_kernels.hpp.inc" + + +} // namespace index_map +} // namespace cuda +} // namespace kernels +} // namespace gko diff --git a/dpcpp/CMakeLists.txt b/dpcpp/CMakeLists.txt index e865dd703e3..8c68efae046 100644 --- a/dpcpp/CMakeLists.txt +++ b/dpcpp/CMakeLists.txt @@ -18,6 +18,7 @@ target_sources(ginkgo_dpcpp base/timer.dp.cpp base/version.dp.cpp components/prefix_sum_kernels.dp.cpp + distributed/index_map_kernels.dp.cpp distributed/matrix_kernels.dp.cpp distributed/partition_helpers_kernels.dp.cpp distributed/partition_kernels.dp.cpp diff --git a/common/unified/distributed/index_map_kernels.cpp b/dpcpp/distributed/index_map_kernels.dp.cpp similarity index 100% rename from common/unified/distributed/index_map_kernels.cpp rename to dpcpp/distributed/index_map_kernels.dp.cpp diff --git a/hip/CMakeLists.txt b/hip/CMakeLists.txt index 148aa296406..046fd1e4d7a 100644 --- a/hip/CMakeLists.txt +++ b/hip/CMakeLists.txt @@ -18,6 +18,7 @@ set(GINKGO_HIP_SOURCES base/timer.hip.cpp base/version.hip.cpp components/prefix_sum_kernels.hip.cpp + distributed/index_map_kernels.hip.cpp distributed/matrix_kernels.hip.cpp distributed/partition_helpers_kernels.hip.cpp distributed/partition_kernels.hip.cpp diff --git 
a/hip/distributed/index_map_kernels.hip.cpp b/hip/distributed/index_map_kernels.hip.cpp new file mode 100644 index 00000000000..d45674a66a3 --- /dev/null +++ b/hip/distributed/index_map_kernels.hip.cpp @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/distributed/index_map_kernels.hpp" + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + + +#include "hip/base/thrust.hip.hpp" +#include "hip/components/atomic.hip.hpp" +#include "hip/components/searching.hip.hpp" + + +namespace gko { +namespace kernels { +namespace hip { +namespace index_map { + + +#include "common/cuda_hip/distributed/index_map_kernels.hpp.inc" + + +} // namespace index_map +} // namespace hip +} // namespace kernels +} // namespace gko diff --git a/omp/CMakeLists.txt b/omp/CMakeLists.txt index 263416fc21c..59d49e44140 100644 --- a/omp/CMakeLists.txt +++ b/omp/CMakeLists.txt @@ -10,6 +10,7 @@ target_sources(ginkgo_omp base/scoped_device_id.cpp base/version.cpp components/prefix_sum_kernels.cpp + distributed/index_map_kernels.cpp distributed/matrix_kernels.cpp distributed/partition_helpers_kernels.cpp distributed/partition_kernels.cpp diff --git a/omp/distributed/index_map_kernels.cpp b/omp/distributed/index_map_kernels.cpp new file mode 100644 index 00000000000..02ae63261a0 --- /dev/null +++ b/omp/distributed/index_map_kernels.cpp @@ -0,0 +1,249 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/distributed/index_map_kernels.hpp" + + +#include + + +#include + + +#include "core/base/allocator.hpp" +#include "core/base/device_matrix_data_kernels.hpp" +#include "core/base/iterator_factory.hpp" +#include "core/base/segmented_array.hpp" +#include "reference/distributed/partition_helpers.hpp" + + +namespace gko { +namespace kernels { +namespace omp { 
+namespace index_map { + + +template +void build_mapping( + std::shared_ptr exec, + const experimental::distributed::Partition* + part, + const array& recv_connections, + array& remote_part_ids, + array& remote_local_idxs, + array& remote_global_idxs, array& remote_sizes) +{ + using experimental::distributed::comm_index_type; + using partition_type = + experimental::distributed::Partition; + auto part_ids = part->get_part_ids(); + const auto* range_bounds = part->get_range_bounds(); + const auto* range_starting_indices = part->get_range_starting_indices(); + const auto num_ranges = part->get_num_ranges(); + auto input_size = recv_connections.get_size(); + + auto recv_connections_copy = recv_connections; + auto recv_connections_ptr = recv_connections_copy.get_data(); + + // precompute the range id and part id of each input element + vector range_ids(input_size, exec); + vector full_remote_part_ids(input_size, exec); + size_type range_id = 0; +#pragma omp parallel for firstprivate(range_id) + for (size_type i = 0; i < input_size; ++i) { + range_id = find_range(recv_connections_ptr[i], part, range_id); + range_ids[i] = range_id; + full_remote_part_ids[i] = part_ids[range_ids[i]]; + } + + // sort by part-id and recv_connection + auto sort_it = detail::make_zip_iterator( + full_remote_part_ids.begin(), recv_connections_ptr, range_ids.begin()); + std::sort(sort_it, sort_it + input_size, [](const auto& a, const auto& b) { + return std::tie(std::get<0>(a), std::get<1>(a)) < + std::tie(std::get<0>(b), std::get<1>(b)); + }); + + // get only unique connections + auto unique_end = std::unique( + sort_it, sort_it + input_size, [](const auto& a, const auto& b) { + return std::tie(std::get<0>(a), std::get<1>(a)) == + std::tie(std::get<0>(b), std::get<1>(b)); + }); + auto unique_size = std::distance(sort_it, unique_end); + + remote_global_idxs.resize_and_reset(unique_size); + auto remote_global_idxs_ptr = remote_global_idxs.get_data(); + 
remote_local_idxs.resize_and_reset(unique_size); + auto remote_local_idxs_ptr = remote_local_idxs.get_data(); + + // store unique connections, also map global indices to local +#pragma omp parallel for + for (size_type i = 0; i < unique_size; ++i) { + remote_global_idxs_ptr[i] = recv_connections_ptr[i]; + remote_local_idxs_ptr[i] = + map_to_local(recv_connections_ptr[i], part, range_ids[i]); + } + + // compute number of connections per part-id + vector full_remote_sizes(part->get_num_parts(), 0, + exec); + +#pragma omp parallel for + for (size_type i = 0; i < unique_size; ++i) { + // std::vector access with [] can count as function call, which + // is not allowed in an atomic expression, thus getting the reference + // before the atomic update. + auto& size = full_remote_sizes[full_remote_part_ids[i]]; +#pragma omp atomic + size++; + } + auto num_neighbors = + full_remote_sizes.size() - + std::count(full_remote_sizes.begin(), full_remote_sizes.end(), 0); + + remote_sizes.resize_and_reset(num_neighbors); + remote_part_ids.resize_and_reset(num_neighbors); + { + size_type idx = 0; + for (size_type i = 0; i < full_remote_sizes.size(); ++i) { + if (full_remote_sizes[i] > 0) { + remote_part_ids.get_data()[idx] = + static_cast(i); + remote_sizes.get_data()[idx] = + static_cast(full_remote_sizes[i]); + idx++; + } + } + } +} + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_BUILD_MAPPING); + + +template +void map_to_local( + std::shared_ptr exec, + const experimental::distributed::Partition* + partition, + const array& remote_target_ids, + device_segmented_array remote_global_idxs, + experimental::distributed::comm_index_type rank, + const array& global_ids, + experimental::distributed::index_space is, array& local_ids) +{ + auto part_ids = partition->get_part_ids(); + auto range_bounds = partition->get_range_bounds(); + auto range_starting_idxs = partition->get_range_starting_indices(); + + local_ids.resize_and_reset(global_ids.get_size()); + + 
// can't extract functions to map global indices to local indices as for + // the reference implementation, because it resulted in internal compiler + // errors for intel 19.1.3 + size_type range_id = 0; + if (is == experimental::distributed::index_space::local) { +#pragma omp parallel for firstprivate(range_id) + for (size_type i = 0; i < global_ids.get_size(); ++i) { + auto gid = global_ids.get_const_data()[i]; + + range_id = find_range(gid, partition, range_id); + auto part_id = part_ids[range_id]; + + local_ids.get_data()[i] = part_id == rank + ? static_cast( + gid - range_bounds[range_id]) + + range_starting_idxs[range_id] + : invalid_index(); + } + } + if (is == experimental::distributed::index_space::non_local) { +#pragma omp parallel for firstprivate(range_id) + for (size_type i = 0; i < global_ids.get_size(); ++i) { + auto gid = global_ids.get_const_data()[i]; + + range_id = find_range(gid, partition, range_id); + auto part_id = part_ids[range_id]; + + // can't do binary search on whole remote_global_idxs array, + // since the array is first sorted by part-id and then by + // global index. As a result, the array is not sorted wrt. + // the global indexing. So find the part-id that corresponds + // to the global index first + auto set_id = std::distance( + remote_target_ids.get_const_data(), + std::lower_bound(remote_target_ids.get_const_data(), + remote_target_ids.get_const_data() + + remote_target_ids.get_size(), + part_id)); + + if (set_id == remote_target_ids.get_size()) { + local_ids.get_data()[i] = invalid_index(); + } else { + auto segment = remote_global_idxs.get_segment(set_id); + + // need to check if *it is actually the current global-id + // since the global-id might not be registered as connected + // to this rank + auto it = std::lower_bound(segment.begin, segment.end, gid); + local_ids.get_data()[i] = + it != segment.end && *it == gid + ? 
static_cast( + std::distance(remote_global_idxs.flat_begin, it)) + : invalid_index(); + } + } + } + if (is == experimental::distributed::index_space::combined) { + auto offset = partition->get_part_sizes()[rank]; +#pragma omp parallel for firstprivate(range_id) default(shared) + for (size_type i = 0; i < global_ids.get_size(); ++i) { + auto gid = global_ids.get_const_data()[i]; + range_id = find_range(gid, partition, range_id); + auto part_id = part_ids[range_id]; + + if (part_id == rank) { + // same as is local; part_id == rank is guaranteed by the + // enclosing branch, so no invalid-index fallback is needed + local_ids.get_data()[i] = + static_cast( + gid - range_bounds[range_id]) + + range_starting_idxs[range_id]; + } else { + // same as is non_local, with additional offset + auto set_id = std::distance( + remote_target_ids.get_const_data(), + std::lower_bound(remote_target_ids.get_const_data(), + remote_target_ids.get_const_data() + + remote_target_ids.get_size(), + part_id)); + + if (set_id == remote_target_ids.get_size()) { + local_ids.get_data()[i] = invalid_index(); + } else { + auto segment = remote_global_idxs.get_segment(set_id); + + auto it = std::lower_bound(segment.begin, segment.end, gid); + local_ids.get_data()[i] = + it != segment.end && *it == gid + ? 
static_cast( + std::distance(remote_global_idxs.flat_begin, + it) + + offset) + : invalid_index(); + } + } + } + } +} + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL); + + +} // namespace index_map +} // namespace omp +} // namespace kernels +} // namespace gko diff --git a/test/distributed/CMakeLists.txt b/test/distributed/CMakeLists.txt index 32b3810ea31..9e0c875de0e 100644 --- a/test/distributed/CMakeLists.txt +++ b/test/distributed/CMakeLists.txt @@ -1,3 +1,4 @@ +ginkgo_create_common_test(index_map_kernels DISABLE_EXECUTORS dpcpp) ginkgo_create_common_test(matrix_kernels DISABLE_EXECUTORS dpcpp) ginkgo_create_common_test(partition_kernels DISABLE_EXECUTORS dpcpp) ginkgo_create_common_test(vector_kernels DISABLE_EXECUTORS dpcpp) diff --git a/test/distributed/index_map_kernels.cpp b/test/distributed/index_map_kernels.cpp new file mode 100644 index 00000000000..458ca594a56 --- /dev/null +++ b/test/distributed/index_map_kernels.cpp @@ -0,0 +1,394 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/distributed/index_map_kernels.hpp" + + +#include +#include + + +#include +#include + + +#include +#include +#include +#include +#include + + +#include "core/distributed/partition_kernels.hpp" +#include "core/test/utils.hpp" +#include "test/utils/executor.hpp" + + +using comm_index_type = gko::experimental::distributed::comm_index_type; + + +template +gko::array generate_connection_idxs( + const std::shared_ptr& exec, comm_index_type rank, + std::shared_ptr> + partition, + std::default_random_engine engine, gko::size_type num_connections) +{ + auto ref = exec->get_master(); + auto num_parts = partition->get_num_parts(); + auto local_size = + static_cast(partition->get_part_size(rank)); + // create vector with [0, ..., num_parts) excluding rank + std::vector part_ids(num_parts - 1); + std::iota(part_ids.begin(), part_ids.end(), rank + 1); + 
std::transform(part_ids.begin(), part_ids.end(), part_ids.begin(), + [&](const auto pid) { return pid % num_parts; }); + // get random connections + std::shuffle(part_ids.begin(), part_ids.end(), engine); + std::vector connected_ids( + part_ids.begin(), part_ids.begin() + num_connections); + // create global index space of connections + std::vector connections_index_space; + for (auto pid : connected_ids) { + for (GlobalIndexType i = 0; i < local_size; ++i) { + connections_index_space.push_back( + i + static_cast(pid * local_size)); + } + } + // generate query from connection_index_space + std::uniform_int_distribution<> dist(0, connections_index_space.size() - 1); + gko::array connection_idxs{ref, 11}; + std::generate_n(connection_idxs.get_data(), connection_idxs.get_size(), + [&] { return connections_index_space[dist(engine)]; }); + return {exec, std::move(connection_idxs)}; +} + + +class IndexMapBuildMapping : public CommonTestFixture {}; + + +TEST_F(IndexMapBuildMapping, BuildMappingSameAsRef) +{ + using local_index_type = gko::int32; + using global_index_type = gko::int64; + using part_type = + gko::experimental::distributed::Partition; + std::default_random_engine engine; + comm_index_type num_parts = 13; + global_index_type local_size = 41; + comm_index_type this_rank = 5; + std::shared_ptr part = part_type::build_from_global_size_uniform( + ref, num_parts, num_parts * local_size); + std::shared_ptr dpart = gko::clone(exec, part); + auto query = generate_connection_idxs(ref, this_rank, part, engine, 11); + auto dquery = gko::array(exec, query); + gko::array target_ids{ref}; + gko::array remote_local_idxs{ref}; + gko::array remote_global_idxs{ref}; + gko::array remote_sizes{ref}; + gko::array dtarget_ids{exec}; + gko::array dremote_local_idxs{exec}; + gko::array dremote_global_idxs{exec}; + gko::array dremote_sizes{exec}; + + gko::kernels::reference::index_map::build_mapping( + ref, part.get(), query, target_ids, remote_local_idxs, + remote_global_idxs, 
remote_sizes); + gko::kernels::EXEC_NAMESPACE::index_map::build_mapping( + exec, dpart.get(), dquery, dtarget_ids, dremote_local_idxs, + dremote_global_idxs, dremote_sizes); + + GKO_ASSERT_ARRAY_EQ(remote_sizes, dremote_sizes); + GKO_ASSERT_ARRAY_EQ(target_ids, dtarget_ids); + GKO_ASSERT_ARRAY_EQ(remote_local_idxs, dremote_local_idxs); + GKO_ASSERT_ARRAY_EQ(remote_global_idxs, dremote_global_idxs); +} + + +class IndexMap : public CommonTestFixture { +protected: + using local_index_type = gko::int32; + using global_index_type = gko::int64; + using part_type = + gko::experimental::distributed::Partition; + using map_type = + gko::experimental::distributed::index_map; + + IndexMap() + { + auto connections = + generate_connection_idxs(ref, this_rank, part, engine, 11); + auto dconnections = gko::array(exec, connections); + + auto flat_remote_local_idxs = gko::array(ref); + auto flat_remote_global_idxs = gko::array(ref); + auto dflat_remote_local_idxs = gko::array(exec); + auto dflat_remote_global_idxs = gko::array(exec); + + auto remote_sizes = gko::array(ref); + auto dremote_sizes = gko::array(exec); + + gko::kernels::reference::index_map::build_mapping( + ref, part.get(), connections, target_ids, flat_remote_local_idxs, + flat_remote_global_idxs, remote_sizes); + gko::kernels::EXEC_NAMESPACE::index_map::build_mapping( + exec, dpart.get(), dconnections, dtarget_ids, + dflat_remote_local_idxs, dflat_remote_global_idxs, dremote_sizes); + + remote_local_idxs = + gko::segmented_array::create_from_sizes( + std::move(flat_remote_local_idxs), remote_sizes); + remote_global_idxs = + gko::segmented_array::create_from_sizes( + std::move(flat_remote_global_idxs), remote_sizes); + dremote_local_idxs = + gko::segmented_array::create_from_sizes( + std::move(dflat_remote_local_idxs), dremote_sizes); + dremote_global_idxs = + gko::segmented_array::create_from_sizes( + std::move(dflat_remote_global_idxs), dremote_sizes); + } + + gko::array generate_query( + std::shared_ptr exec, + 
const gko::array& connection_idxs, + gko::size_type num_queries) + { + auto host_connection_idxs = + gko::make_temporary_clone(ref, &connection_idxs); + // generate query from connection_index_space + std::uniform_int_distribution<> dist(0, connection_idxs.get_size() - 1); + gko::array query{ref, num_queries}; + std::generate_n(query.get_data(), query.get_size(), [&] { + return host_connection_idxs->get_const_data()[dist(engine)]; + }); + return {std::move(exec), std::move(query)}; + } + + gko::array generate_complement_idxs( + std::shared_ptr exec, + const gko::array& idxs) + { + auto host_idxs = gko::make_temporary_clone(ref, &idxs); + std::vector full_idxs(part->get_size()); + std::iota(full_idxs.begin(), full_idxs.end(), 0); + + std::set idxs_set( + host_idxs->get_const_data(), + host_idxs->get_const_data() + host_idxs->get_size()); + + auto end = std::remove_if( + full_idxs.begin(), full_idxs.end(), + [&](const auto v) { return idxs_set.find(v) != idxs_set.end(); }); + auto complement_size = std::distance(full_idxs.begin(), end); + return {std::move(exec), full_idxs.begin(), end}; + } + + + gko::array combine_arrays( + std::shared_ptr exec, + const gko::array& a, + const gko::array& b) + { + gko::array result(exec, a.get_size() + b.get_size()); + exec->copy_from(a.get_executor(), a.get_size(), a.get_const_data(), + result.get_data()); + exec->copy_from(b.get_executor(), b.get_size(), b.get_const_data(), + result.get_data() + a.get_size()); + return result; + } + + gko::array take_random( + const gko::array& a, gko::size_type n) + { + auto copy = gko::array(ref, a); + std::shuffle(copy.get_data(), copy.get_data() + copy.get_size(), + engine); + + return {a.get_executor(), copy.get_const_data(), + copy.get_const_data() + n}; + } + + gko::array target_ids{ref}; + gko::segmented_array remote_local_idxs{ref}; + gko::segmented_array remote_global_idxs{ref}; + gko::array dtarget_ids{exec}; + gko::segmented_array dremote_local_idxs{exec}; + gko::segmented_array 
dremote_global_idxs{exec}; + + comm_index_type num_parts = 13; + global_index_type local_size = 41; + comm_index_type this_rank = 5; + + std::shared_ptr part = part_type::build_from_global_size_uniform( + ref, num_parts, num_parts* local_size); + std::shared_ptr dpart = gko::clone(exec, part); + + std::default_random_engine engine; +}; + + +TEST_F(IndexMap, GetLocalWithLocalIndexSpaceSameAsRef) +{ + auto local_space = gko::array(ref, local_size); + std::iota(local_space.get_data(), local_space.get_data() + local_size, + this_rank * local_size); + auto query = generate_query(ref, local_space, 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_local( + ref, part.get(), target_ids, to_device_const(remote_global_idxs), + this_rank, query, gko::experimental::distributed::index_space::local, + result); + gko::kernels::EXEC_NAMESPACE::index_map::map_to_local( + exec, dpart.get(), dtarget_ids, to_device_const(dremote_global_idxs), + this_rank, dquery, gko::experimental::distributed::index_space::local, + dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetLocalWithLocalIndexSpaceWithInvalidIndexSameAsRef) +{ + auto local_space = gko::array(ref, local_size); + std::iota(local_space.get_data(), local_space.get_data() + local_size, + this_rank * local_size); + auto query = generate_query( + ref, + combine_arrays( + ref, local_space, + take_random(generate_complement_idxs(ref, local_space), 12)), + 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_local( + ref, part.get(), target_ids, to_device_const(remote_global_idxs), + this_rank, query, gko::experimental::distributed::index_space::local, + result); + gko::kernels::EXEC_NAMESPACE::index_map::map_to_local( + exec, dpart.get(), dtarget_ids, to_device_const(dremote_global_idxs), + 
this_rank, dquery, gko::experimental::distributed::index_space::local, + dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +template +gko::array get_flat_array(const gko::segmented_array& arr) +{ + return gko::make_const_array_view(arr.get_executor(), arr.get_size(), + arr.get_const_flat_data()) + .copy_to_array(); +} + + +TEST_F(IndexMap, GetLocalWithNonLocalIndexSpaceSameAsRef) +{ + auto query = generate_query(ref, get_flat_array(remote_global_idxs), 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_local( + ref, part.get(), target_ids, to_device_const(remote_global_idxs), + this_rank, query, + gko::experimental::distributed::index_space::non_local, result); + gko::kernels::EXEC_NAMESPACE::index_map::map_to_local( + exec, dpart.get(), dtarget_ids, to_device_const(dremote_global_idxs), + this_rank, dquery, + gko::experimental::distributed::index_space::non_local, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetLocalWithNonLocalIndexSpaceWithInvalidIndexSameAsRef) +{ + auto query = generate_query( + ref, + combine_arrays(ref, get_flat_array(remote_global_idxs), + take_random(generate_complement_idxs( + ref, get_flat_array(remote_global_idxs)), + 12)), + 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_local( + ref, part.get(), target_ids, to_device_const(remote_global_idxs), + this_rank, query, + gko::experimental::distributed::index_space::non_local, result); + gko::kernels::EXEC_NAMESPACE::index_map::map_to_local( + exec, dpart.get(), dtarget_ids, to_device_const(dremote_global_idxs), + this_rank, dquery, + gko::experimental::distributed::index_space::non_local, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetLocalWithCombinedIndexSpaceSameAsRef) +{ + auto local_space = 
gko::array(ref, local_size); + std::iota(local_space.get_data(), local_space.get_data() + local_size, + this_rank * local_size); + auto combined_space = + combine_arrays(ref, local_space, get_flat_array(remote_global_idxs)); + auto query = generate_query(ref, combined_space, 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_local( + ref, part.get(), target_ids, to_device_const(remote_global_idxs), + this_rank, query, gko::experimental::distributed::index_space::combined, + result); + gko::kernels::EXEC_NAMESPACE::index_map::map_to_local( + exec, dpart.get(), dtarget_ids, to_device_const(dremote_global_idxs), + this_rank, dquery, + gko::experimental::distributed::index_space::combined, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetLocalWithCombinedIndexSpaceWithInvalidIndexSameAsRef) +{ + auto local_space = gko::array(ref, local_size); + std::iota(local_space.get_data(), local_space.get_data() + local_size, + this_rank * local_size); + auto combined_space = + combine_arrays(ref, local_space, get_flat_array(remote_global_idxs)); + auto query = generate_query( + ref, + combine_arrays( + ref, combined_space, + take_random(generate_complement_idxs(ref, combined_space), 12)), + 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_local( + ref, part.get(), target_ids, to_device_const(remote_global_idxs), + this_rank, query, + gko::experimental::distributed::index_space::combined, result); + gko::kernels::EXEC_NAMESPACE::index_map::map_to_local( + exec, dpart.get(), dtarget_ids, to_device_const(dremote_global_idxs), + this_rank, dquery, + gko::experimental::distributed::index_space::combined, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +}