Add device implementation of DD matrix kernel.
fritzgoebel committed Nov 4, 2024
1 parent 1c0e4e9 commit a3fb354
Showing 3 changed files with 179 additions and 4 deletions.
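The kernel implemented here, filter_non_owning_idxs, takes a COO matrix assembled with global indices plus row and column partitions, and collects every row and column index whose owning part is not local_part. As a reference point, here is a minimal serial model of that contract (illustrative types and names, not the Ginkgo API):

#include <cstddef>
#include <vector>

// Serial model of filter_non_owning_idxs: owner_of maps a global index to
// the part that owns it (a stand-in for the partition lookups in the
// kernels below). Indices not owned by local_part are collected, with
// duplicates preserved.
void filter_non_owning_idxs_model(const std::vector<long>& row_idxs,
                                  const std::vector<long>& col_idxs,
                                  int (*owner_of)(long), int local_part,
                                  std::vector<long>& non_owning_rows,
                                  std::vector<long>& non_owning_cols)
{
    for (std::size_t i = 0; i < row_idxs.size(); ++i) {
        if (owner_of(row_idxs[i]) != local_part) {
            non_owning_rows.push_back(row_idxs[i]);
        }
        if (owner_of(col_idxs[i]) != local_part) {
            non_owning_cols.push_back(col_idxs[i]);
        }
    }
}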
89 changes: 88 additions & 1 deletion common/cuda_hip/distributed/dd_matrix_kernels.cpp
@@ -37,7 +37,94 @@ void filter_non_owning_idxs(
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
col_partition,
comm_index_type local_part, array<GlobalIndexType>& non_local_row_idxs,
- array<GlobalIndexType>& non_local_col_idxs) GKO_NOT_IMPLEMENTED;
+ array<GlobalIndexType>& non_local_col_idxs)
{
auto input_vals = input.get_const_values();
auto row_part_ids = row_partition->get_part_ids();
auto col_part_ids = col_partition->get_part_ids();
const auto* row_range_bounds = row_partition->get_range_bounds();
const auto* col_range_bounds = col_partition->get_range_bounds();
const auto* row_range_starting_indices =
row_partition->get_range_starting_indices();
const auto* col_range_starting_indices =
col_partition->get_range_starting_indices();
const auto num_row_ranges = row_partition->get_num_ranges();
const auto num_col_ranges = col_partition->get_num_ranges();
const auto num_input_elements = input.get_num_stored_elements();

auto policy = thrust_policy(exec);

// precompute the row and column range id of each input element
auto input_row_idxs = input.get_const_row_idxs();
auto input_col_idxs = input.get_const_col_idxs();
array<size_type> row_range_ids{exec, num_input_elements};
thrust::upper_bound(policy, row_range_bounds + 1,
row_range_bounds + num_row_ranges + 1, input_row_idxs,
input_row_idxs + num_input_elements,
row_range_ids.get_data());
array<size_type> col_range_ids{exec, num_input_elements};
thrust::upper_bound(policy, col_range_bounds + 1,
col_range_bounds + num_col_ranges + 1, input_col_idxs,
input_col_idxs + num_input_elements,
col_range_ids.get_data());

// count the number of non-local row and column indices
auto range_ids_it = thrust::make_zip_iterator(thrust::make_tuple(
row_range_ids.get_const_data(), col_range_ids.get_const_data()));
auto num_elements_pair = thrust::transform_reduce(
policy, range_ids_it, range_ids_it + num_input_elements,
[local_part, row_part_ids, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
auto col_part = col_part_ids[thrust::get<1>(tuple)];
bool is_local_row = row_part == local_part;
bool is_local_col = col_part == local_part;
return thrust::make_tuple(
is_local_row ? size_type{0} : size_type{1},
is_local_col ? size_type{0} : size_type{1});
},
thrust::make_tuple(size_type{}, size_type{}),
[] __host__ __device__(const thrust::tuple<size_type, size_type>& a,
const thrust::tuple<size_type, size_type>& b) {
return thrust::make_tuple(thrust::get<0>(a) + thrust::get<0>(b),
thrust::get<1>(a) + thrust::get<1>(b));
});
auto n_non_local_row_idxs = thrust::get<0>(num_elements_pair);
auto n_non_local_col_idxs = thrust::get<1>(num_elements_pair);

// define global-to-local maps for row and column indices
auto map_to_local_row =
[row_range_bounds, row_range_starting_indices] __host__ __device__(
const GlobalIndexType row, const size_type range_id) {
return static_cast<LocalIndexType>(row -
row_range_bounds[range_id]) +
row_range_starting_indices[range_id];
};
auto map_to_local_col =
[col_range_bounds, col_range_starting_indices] __host__ __device__(
const GlobalIndexType col, const size_type range_id) {
return static_cast<LocalIndexType>(col -
col_range_bounds[range_id]) +
col_range_starting_indices[range_id];
};

non_local_col_idxs.resize_and_reset(n_non_local_col_idxs);
non_local_row_idxs.resize_and_reset(n_non_local_row_idxs);
thrust::copy_if(policy, input_col_idxs, input_col_idxs + num_input_elements,
range_ids_it, non_local_col_idxs.get_data(),
[local_part, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto col_part = col_part_ids[thrust::get<1>(tuple)];
return col_part != local_part;
});
thrust::copy_if(policy, input_row_idxs, input_row_idxs + num_input_elements,
range_ids_it, non_local_row_idxs.get_data(),
[local_part, row_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
return row_part != local_part;
});
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_FILTER_NON_OWNING_IDXS);
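
The range-id lookup above relies on the vectorized binary search thrust::upper_bound: since range r of a partition covers the half-open interval [bounds[r], bounds[r + 1]), searching over bounds + 1 yields the id of the containing range for each query index. A host-side sketch of the same idea with std::upper_bound (illustrative, serial):

#include <algorithm>
#include <cstddef>
#include <vector>

// Host-side sketch of the range-id lookup: range r covers global indices
// [bounds[r], bounds[r + 1]). Searching bounds + 1 returns, for each query
// index, the id of the range containing it, mirroring the
// thrust::upper_bound calls in the device kernel above.
std::vector<std::size_t> find_range_ids(const std::vector<long>& bounds,
                                        const std::vector<long>& idxs)
{
    std::vector<std::size_t> range_ids(idxs.size());
    for (std::size_t i = 0; i < idxs.size(); ++i) {
        auto it = std::upper_bound(bounds.begin() + 1, bounds.end(), idxs[i]);
        range_ids[i] = static_cast<std::size_t>(it - (bounds.begin() + 1));
    }
    return range_ids;
}

For example, with bounds = {0, 3, 7, 10}, the query index 3 maps to range 1 (interval [3, 7)) and the query index 9 maps to range 2 (interval [7, 10)).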
4 changes: 3 additions & 1 deletion core/distributed/dd_matrix.cpp
@@ -206,8 +206,10 @@ void DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(

// Gather local sizes from all ranks and build the partition in the enriched
// space.
- array<GlobalIndexType> range_bounds{exec, num_parts + 1};
+ array<GlobalIndexType> range_bounds{
+     use_host_buffer ? exec->get_master() : exec, num_parts + 1};
  comm.all_gather(exec, &local_num_rows, 1, range_bounds.get_data(), 1);
+ range_bounds.set_executor(exec);
exec->run(dd_matrix::make_prefix_sum_nonnegative(range_bounds.get_data(),
num_parts + 1));
auto large_partition =
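The change above stages the all_gather through a host buffer when use_host_buffer is set (e.g. for an MPI library that is not device-aware) and moves the result back to the device afterwards. A sketch of the underlying host-side pattern in plain MPI (not Ginkgo code, illustrative names); the exclusive prefix sum mirrors the make_prefix_sum_nonnegative call:

#include <mpi.h>

#include <numeric>
#include <vector>

// Gather every rank's local row count, then turn the counts into range
// bounds with an exclusive prefix sum, so that bounds[i] is the first
// global row of part i and bounds[num_parts] the total number of rows.
std::vector<long long> gather_range_bounds(MPI_Comm comm,
                                           long long local_num_rows)
{
    int num_parts{};
    MPI_Comm_size(comm, &num_parts);
    std::vector<long long> bounds(num_parts + 1, 0);
    MPI_Allgather(&local_num_rows, 1, MPI_LONG_LONG, bounds.data(), 1,
                  MPI_LONG_LONG, comm);
    std::exclusive_scan(bounds.begin(), bounds.end(), bounds.begin(), 0LL);
    return bounds;
}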
90 changes: 88 additions & 2 deletions omp/distributed/dd_matrix_kernels.cpp
@@ -28,8 +28,94 @@ void filter_non_owning_idxs(
row_partition,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
col_partition,
- comm_index_type local_part, array<GlobalIndexType>& non_local_row_idxs,
- array<GlobalIndexType>& non_local_col_idxs) GKO_NOT_IMPLEMENTED;
+ comm_index_type local_part, array<GlobalIndexType>& non_owning_row_idxs,
+ array<GlobalIndexType>& non_owning_col_idxs)
{
auto input_row_idxs = input.get_const_row_idxs();
auto input_col_idxs = input.get_const_col_idxs();
auto input_vals = input.get_const_values();
auto row_part_ids = row_partition->get_part_ids();
auto col_part_ids = col_partition->get_part_ids();
auto num_parts = row_partition->get_num_parts();
size_type row_range_id_hint = 0;
size_type col_range_id_hint = 0;

// collect non-local row and column indices in global numbering
vector<GlobalIndexType> non_local_row_idxs(exec);
vector<GlobalIndexType> non_local_col_idxs(exec);

auto num_threads = static_cast<size_type>(omp_get_max_threads());
auto num_input = input.get_num_stored_elements();
auto size_per_thread = (num_input + num_threads - 1) / num_threads;
vector<size_type> non_local_col_offsets(num_threads, 0, exec);
vector<size_type> non_local_row_offsets(num_threads, 0, exec);

#pragma omp parallel firstprivate(col_range_id_hint, row_range_id_hint)
{
vector<GlobalIndexType> thread_non_local_col_idxs(exec);
vector<GlobalIndexType> thread_non_local_row_idxs(exec);
auto thread_id = omp_get_thread_num();
auto thread_begin = thread_id * size_per_thread;
auto thread_end = std::min(thread_begin + size_per_thread, num_input);
// Count non-local row and column idxs per thread
for (size_type i = thread_begin; i < thread_end; i++) {
auto global_col = input_col_idxs[i];
auto global_row = input_row_idxs[i];
col_range_id_hint =
find_range(global_col, col_partition, col_range_id_hint);
row_range_id_hint =
find_range(global_row, row_partition, row_range_id_hint);
if (col_part_ids[col_range_id_hint] != local_part) {
thread_non_local_col_idxs.push_back(global_col);
}
if (row_part_ids[row_range_id_hint] != local_part) {
thread_non_local_row_idxs.push_back(global_row);
}
}
non_local_col_offsets[thread_id] = thread_non_local_col_idxs.size();
non_local_row_offsets[thread_id] = thread_non_local_row_idxs.size();

#pragma omp barrier
#pragma omp single
{
// assign output ranges to the individual threads
size_type n_non_local_col_idxs{};
size_type n_non_local_row_idxs{};
for (size_type thread = 0; thread < num_threads; thread++) {
auto size_col_idxs = non_local_col_offsets[thread];
auto size_row_idxs = non_local_row_offsets[thread];
non_local_col_offsets[thread] = n_non_local_col_idxs;
non_local_row_offsets[thread] = n_non_local_row_idxs;
n_non_local_col_idxs += size_col_idxs;
n_non_local_row_idxs += size_row_idxs;
}
non_local_col_idxs.resize(n_non_local_col_idxs);
non_local_row_idxs.resize(n_non_local_row_idxs);
}
// write back the non_local idxs to the output ranges
auto col_counter = non_local_col_offsets[thread_id];
auto row_counter = non_local_row_offsets[thread_id];
for (const auto& non_local_col : thread_non_local_col_idxs) {
non_local_col_idxs[col_counter] = non_local_col;
col_counter++;
}
for (const auto& non_local_row : thread_non_local_row_idxs) {
non_local_row_idxs[row_counter] = non_local_row;
row_counter++;
}
}

non_owning_col_idxs.resize_and_reset(non_local_col_idxs.size());
#pragma omp parallel for
for (size_type i = 0; i < non_local_col_idxs.size(); i++) {
non_owning_col_idxs.get_data()[i] = non_local_col_idxs[i];
}
non_owning_row_idxs.resize_and_reset(non_local_row_idxs.size());
#pragma omp parallel for
for (size_type i = 0; i < non_local_row_idxs.size(); i++) {
non_owning_row_idxs.get_data()[i] = non_local_row_idxs[i];
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_FILTER_NON_OWNING_IDXS);
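
The OpenMP variant uses a classic two-pass pattern: each thread filters its chunk into a private buffer, one thread converts the per-thread counts into offsets with an exclusive prefix sum, and all threads then copy their buffers into disjoint slices of the shared output. A self-contained sketch of that pattern (illustrative names; assumes the team size equals omp_get_max_threads()):

#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <vector>

// Two-pass parallel filter: 1) each thread filters its chunk into a private
// buffer and publishes the buffer size, 2) a single thread turns the sizes
// into offsets via an exclusive prefix sum and resizes the output, 3) every
// thread copies its buffer into its disjoint slice of the output.
std::vector<int> parallel_filter(const std::vector<int>& in, bool (*keep)(int))
{
    const int num_threads = omp_get_max_threads();
    std::vector<std::size_t> offsets(num_threads, 0);
    std::vector<int> out;
#pragma omp parallel num_threads(num_threads)
    {
        const auto tid = omp_get_thread_num();
        const auto chunk = (in.size() + num_threads - 1) / num_threads;
        const auto begin = std::min(std::size_t(tid) * chunk, in.size());
        const auto end = std::min(begin + chunk, in.size());
        std::vector<int> local;
        for (auto i = begin; i < end; ++i) {
            if (keep(in[i])) {
                local.push_back(in[i]);
            }
        }
        offsets[tid] = local.size();
#pragma omp barrier
#pragma omp single
        {
            std::size_t total = 0;
            for (auto& offset : offsets) {
                const auto size = offset;
                offset = total;  // exclusive prefix sum over the counts
                total += size;
            }
            out.resize(total);
        }  // implicit barrier: offsets and out are ready for every thread
        std::copy(local.begin(), local.end(), out.begin() + offsets[tid]);
    }
    return out;
}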
