Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Index map in PGM #1639

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions benchmark/test/reference/distributed_solver.profile.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,28 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
Expand Down Expand Up @@ -82,8 +98,6 @@ DEBUG: begin copy
DEBUG: end copy
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::convert_idxs_to_ptrs
DEBUG: end components::convert_idxs_to_ptrs
DEBUG: begin components::convert_idxs_to_ptrs
Expand Down
60 changes: 60 additions & 0 deletions benchmark/test/reference/multi_vector_distributed.profile.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,24 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_from_global_size
Expand All @@ -34,12 +46,24 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin dense::fill
Expand All @@ -61,12 +85,24 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_from_global_size
Expand All @@ -79,12 +115,24 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin dense::fill
Expand All @@ -110,12 +158,24 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin dense::fill
Expand Down
42 changes: 40 additions & 2 deletions benchmark/test/reference/spmv_distributed.profile.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,24 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::aos_to_soa
Expand All @@ -38,12 +50,24 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::aos_to_soa
Expand All @@ -66,12 +90,28 @@ DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin partition::build_from_contiguous
DEBUG: end partition::build_from_contiguous
DEBUG: begin partition::build_starting_indices
DEBUG: end partition::build_starting_indices
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin partition::build_ranges_by_part
DEBUG: end partition::build_ranges_by_part
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::prefix_sum_nonnegative
DEBUG: end components::prefix_sum_nonnegative
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
DEBUG: end components::fill_array
DEBUG: begin components::fill_array
Expand Down Expand Up @@ -134,8 +174,6 @@ DEBUG: begin copy
DEBUG: end copy
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin copy
DEBUG: end copy
DEBUG: begin components::convert_idxs_to_ptrs
DEBUG: end components::convert_idxs_to_ptrs
DEBUG: begin components::convert_idxs_to_ptrs
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test/reference/spmv_distributed.profile.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"comm_pattern": "stencil",
"spmv": {
"csr-csr": {
"storage": 6420,
"storage": 6692,
"time": 1.0,
"repetitions": 1,
"completed": true
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test/reference/spmv_distributed.simple.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"comm_pattern": "stencil",
"spmv": {
"csr-csr": {
"storage": 6420,
"storage": 6692,
"max_relative_norm2": 1.0,
"time": 1.0,
"repetitions": 10,
Expand Down
84 changes: 84 additions & 0 deletions common/cuda_hip/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,90 @@ GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL);


/**
 * Maps indices given in one of the index spaces of this rank (local,
 * non-local or combined) to global indices.
 *
 * @param exec  executor used to build the thrust execution policy
 * @param partition  device view of the partition; the range bounds, starting
 *                   indices, part sizes and the per-part range lists are read
 * @param remote_global_idxs  flattened list of the global indices backing the
 *                            non-local index space
 * @param rank  the part whose local index space is used
 * @param local_ids  the indices to map, interpreted in the index space `is`
 * @param is  the index space `local_ids` refers to
 * @param global_ids  output, resized to `local_ids.get_size()`; entries whose
 *                    input index lies outside of the selected index space are
 *                    set to `invalid_index<GlobalIndexType>()`
 */
template <typename LocalIndexType, typename GlobalIndexType>
void map_to_global(
    std::shared_ptr<const DefaultExecutor> exec,
    device_partition<const LocalIndexType, const GlobalIndexType> partition,
    device_segmented_array<const GlobalIndexType> remote_global_idxs,
    experimental::distributed::comm_index_type rank,
    const array<LocalIndexType>& local_ids,
    experimental::distributed::index_space is,
    array<GlobalIndexType>& global_ids)
{
    auto range_bounds = partition.offsets_begin;
    auto starting_indices = partition.starting_indices_begin;
    const auto& ranges_by_part = partition.ranges_by_part;
    auto local_ids_it = local_ids.get_const_data();
    auto input_size = local_ids.get_size();

    auto policy = thrust_policy(exec);

    global_ids.resize_and_reset(input_size);
    auto global_ids_it = global_ids.get_data();

    // index in the local index space of `rank` -> global index
    auto map_local = [rank, ranges_by_part, range_bounds, starting_indices,
                      partition] __device__(auto lid) {
        auto local_size =
            static_cast<LocalIndexType>(partition.part_sizes_begin[rank]);

        if (lid < 0 || lid >= local_size) {
            return invalid_index<GlobalIndexType>();
        }

        auto local_ranges = ranges_by_part.get_segment(rank);
        auto local_ranges_size =
            static_cast<int64>(local_ranges.end - local_ranges.begin);

        // Find the first range of this part that starts strictly after lid;
        // the range containing lid is the one right before it.
        // Searching with `>=` instead would pick the *following* range for
        // every lid that is not exactly a range start, which only yields the
        // right global index if consecutive local ranges are also adjacent
        // in the global index space.
        // NOTE(review): relies on binary_search returning the first index for
        // which the predicate is true (local_ranges_size if there is none)
        // and on the starting indices of a part being sorted ascendingly.
        auto it = binary_search(int64(0), local_ranges_size, [=](const auto i) {
            return starting_indices[local_ranges.begin[i]] > lid;
        });
        // clamp to guard against a part whose first range does not start at
        // or before lid
        auto local_range_id = max(int64(0), it - 1);
        auto range_id = local_ranges.begin[local_range_id];

        // offset within the range + global start of the range
        return static_cast<GlobalIndexType>(lid - starting_indices[range_id]) +
               range_bounds[range_id];
    };
    // index in the non-local index space -> global index (plain lookup)
    auto map_non_local = [remote_global_idxs] __device__(auto lid) {
        auto remote_size = static_cast<LocalIndexType>(
            remote_global_idxs.flat_end - remote_global_idxs.flat_begin);

        if (lid < 0 || lid >= remote_size) {
            return invalid_index<GlobalIndexType>();
        }

        return remote_global_idxs.flat_begin[lid];
    };
    // combined index space: local indices first, followed by the non-local
    // indices shifted by the local size
    auto map_combined = [map_local, map_non_local, partition,
                         rank] __device__(auto lid) {
        auto local_size =
            static_cast<LocalIndexType>(partition.part_sizes_begin[rank]);

        if (lid < local_size) {
            // also handles negative indices, which map_local rejects
            return map_local(lid);
        } else {
            return map_non_local(lid - local_size);
        }
    };

    if (is == experimental::distributed::index_space::local) {
        thrust::transform(policy, local_ids_it, local_ids_it + input_size,
                          global_ids_it, map_local);
    }
    if (is == experimental::distributed::index_space::non_local) {
        thrust::transform(policy, local_ids_it, local_ids_it + input_size,
                          global_ids_it, map_non_local);
    }
    if (is == experimental::distributed::index_space::combined) {
        thrust::transform(policy, local_ids_it, local_ids_it + input_size,
                          global_ids_it, map_combined);
    }
}

// Presumably emits the explicit instantiations of map_to_global for every
// supported (LocalIndexType, GlobalIndexType) pair - both macros are defined
// outside of this file, TODO confirm.
GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL);


} // namespace index_map
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
Expand Down
Loading
Loading