[dist-mat] use row-gatherer
MarcelKoch committed May 3, 2024
1 parent 3fe67e1 commit d5fcbfc
Showing 2 changed files with 52 additions and 156 deletions.
182 changes: 44 additions & 138 deletions core/distributed/matrix.cpp
@@ -6,6 +6,7 @@


#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/csr.hpp>

@@ -45,14 +46,10 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
: EnableDistributedLinOp<
Matrix<value_type, local_index_type, global_index_type>>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
send_sizes_(comm.size()),
recv_offsets_(comm.size() + 1),
recv_sizes_(comm.size()),
gather_idxs_{exec},
one_scalar_{},
local_mtx_{local_matrix_template->clone(exec)},
non_local_mtx_{non_local_matrix_template->clone(exec)},
row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
imap_{exec}
{
GKO_ASSERT(
@@ -73,13 +70,9 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
: EnableDistributedLinOp<
Matrix<value_type, local_index_type, global_index_type>>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
send_sizes_(comm.size()),
recv_offsets_(comm.size() + 1),
recv_sizes_(comm.size()),
gather_idxs_{exec},
one_scalar_{},
imap_{exec}
row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
imap_{exec},
one_scalar_{}
{
this->set_size(size);
one_scalar_.init(exec, dim<2>{1, 1});
@@ -138,11 +131,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
result->get_communicator().size());
result->local_mtx_->copy_from(this->local_mtx_);
result->non_local_mtx_->copy_from(this->non_local_mtx_);
result->gather_idxs_ = this->gather_idxs_;
result->send_offsets_ = this->send_offsets_;
result->recv_offsets_ = this->recv_offsets_;
result->recv_sizes_ = this->recv_sizes_;
result->send_sizes_ = this->send_sizes_;
result->row_gatherer_->copy_from(this->row_gatherer_);
result->set_size(this->get_size());
}

@@ -156,11 +145,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
result->get_communicator().size());
result->local_mtx_->move_from(this->local_mtx_);
result->non_local_mtx_->move_from(this->non_local_mtx_);
result->gather_idxs_ = std::move(this->gather_idxs_);
result->send_offsets_ = std::move(this->send_offsets_);
result->recv_offsets_ = std::move(this->recv_offsets_);
result->recv_sizes_ = std::move(this->recv_sizes_);
result->send_sizes_ = std::move(this->send_sizes_);
result->row_gatherer_->move_from(this->row_gatherer_);
result->set_size(this->get_size());
this->set_size({});
}
@@ -183,7 +168,6 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
auto local_part = comm.rank();

// set up LinOp sizes
auto num_parts = static_cast<size_type>(row_partition->get_num_parts());
auto global_num_rows = row_partition->get_size();
auto global_num_cols = col_partition->get_size();
dim<2> global_dim{global_num_rows, global_num_cols};
@@ -230,44 +214,9 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
as<ReadableFromMatrixData<ValueType, LocalIndexType>>(this->non_local_mtx_)
->read(std::move(non_local_data));

// exchange step 1: determine recv_sizes, send_sizes, send_offsets
auto host_recv_targets = make_temporary_clone(
exec->get_master(), &imap_.get_remote_target_ids());
auto host_offsets = make_temporary_clone(
exec->get_master(), &imap_.get_remote_global_idxs().get_offsets());
std::fill(recv_sizes_.begin(), recv_sizes_.end(), 0);
for (size_type i = 0; i < host_recv_targets->get_size(); ++i) {
recv_sizes_[host_recv_targets->get_const_data()[i]] =
host_offsets->get_const_data()[i + 1] -
host_offsets->get_const_data()[i];
}
std::partial_sum(recv_sizes_.begin(), recv_sizes_.end(),
recv_offsets_.begin() + 1);
comm.all_to_all(exec, recv_sizes_.data(), 1, send_sizes_.data(), 1);
std::partial_sum(send_sizes_.begin(), send_sizes_.end(),
send_offsets_.begin() + 1);
send_offsets_[0] = 0;
recv_offsets_[0] = 0;

// exchange step 2: exchange gather_idxs from receivers to senders
auto recv_gather_idxs =
make_const_array_view(imap_.get_executor(), imap_.get_non_local_size(),
imap_.get_remote_local_idxs().get_flat_data())
.copy_to_array();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_gather_idxs.set_executor(exec->get_master());
gather_idxs_.clear();
gather_idxs_.set_executor(exec->get_master());
}
gather_idxs_.resize_and_reset(send_offsets_.back());
comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec,
recv_gather_idxs.get_const_data(), recv_sizes_.data(),
recv_offsets_.data(), gather_idxs_.get_data(),
send_sizes_.data(), send_offsets_.data());
if (use_host_buffer) {
gather_idxs_.set_executor(exec);
}
row_gatherer_ = RowGatherer<local_index_type>::create(
exec, std::make_shared<mpi::neighborhood_communicator>(comm, imap_),
imap_);
}
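
The manual bookkeeping removed above (recv_sizes_/send_sizes_, the offset prefix sums, and the all-to-all exchange of gather indices) is what the neighborhood communicator built from imap_ encapsulates: the sparse send/receive topology is fixed once when the matrix is read and reused for every SpMV. Below is a conceptual sketch of that idea in plain MPI, not Ginkgo's internal API; the two-process topology, counts, and buffer contents are made up for illustration, while in the code above they come from the index map.

#include <mpi.h>
#include <vector>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Hypothetical two-process topology: each rank only exchanges with its peer.
    int peer = 1 - rank;
    std::vector<int> sources{peer}, destinations{peer};
    std::vector<int> send_counts{3}, send_displs{0};
    std::vector<int> recv_counts{3}, recv_displs{0};

    // Fix the sparse communication topology once.
    MPI_Comm neighbor_comm;
    MPI_Dist_graph_create_adjacent(
        MPI_COMM_WORLD, static_cast<int>(sources.size()), sources.data(),
        MPI_UNWEIGHTED, static_cast<int>(destinations.size()),
        destinations.data(), MPI_UNWEIGHTED, MPI_INFO_NULL, /*reorder=*/0,
        &neighbor_comm);

    // Per-apply exchange of the shared vector entries.
    std::vector<double> send_buf{1.0, 2.0, 3.0}, recv_buf(3);
    MPI_Neighbor_alltoallv(send_buf.data(), send_counts.data(),
                           send_displs.data(), MPI_DOUBLE, recv_buf.data(),
                           recv_counts.data(), recv_displs.data(), MPI_DOUBLE,
                           neighbor_comm);

    MPI_Comm_free(&neighbor_comm);
    MPI_Finalize();
}

Compared with a dense all-to-all over the full communicator, the neighbor collective only involves the ranks that actually share matrix columns.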


@@ -309,53 +258,6 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
mpi::request Matrix<ValueType, LocalIndexType, GlobalIndexType>::communicate(
const local_vector_type* local_b) const
{
if (!non_local_mtx_) {
return {};
}
auto exec = this->get_executor();
const auto comm = this->get_communicator();
auto num_cols = local_b->get_size()[1];
auto send_size = send_offsets_.back();
auto recv_size = recv_offsets_.back();
auto send_dim = dim<2>{static_cast<size_type>(send_size), num_cols};
auto recv_dim = dim<2>{static_cast<size_type>(recv_size), num_cols};
recv_buffer_.init(exec, recv_dim);
send_buffer_.init(exec, send_dim);

local_b->row_gather(&gather_idxs_, send_buffer_.get());

auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
host_recv_buffer_.init(exec->get_master(), recv_dim);
host_send_buffer_.init(exec->get_master(), send_dim);
host_send_buffer_->copy_from(send_buffer_.get());
}

mpi::contiguous_type type(num_cols, mpi::type_impl<ValueType>::get_type());
auto send_ptr = use_host_buffer ? host_send_buffer_->get_const_values()
: send_buffer_->get_const_values();
auto recv_ptr = use_host_buffer ? host_recv_buffer_->get_values()
: recv_buffer_->get_values();
exec->synchronize();
#ifdef GINKGO_FORCE_SPMV_BLOCKING_COMM
comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
send_sizes_.data(), send_offsets_.data(), type.get(),
recv_ptr, recv_sizes_.data(), recv_offsets_.data(),
type.get());
return {};
#else
return comm.i_all_to_all_v(
use_host_buffer ? exec->get_master() : exec, send_ptr,
send_sizes_.data(), send_offsets_.data(), type.get(), recv_ptr,
recv_sizes_.data(), recv_offsets_.data(), type.get());
#endif
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
const LinOp* b, LinOp* x) const
@@ -371,20 +273,24 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());

auto exec = this->get_executor();
auto comm = this->get_communicator();
auto req = this->communicate(dense_b->get_local_vector());
auto recv_dim =
dim<2>{static_cast<size_type>(
row_gatherer_->get_collective_communicator()
->get_recv_size()),
dense_b->get_size()[1]};
auto recv_exec = mpi::requires_host_buffer(exec, comm)
? exec->get_master()
: exec;
recv_buffer_.init(recv_exec, recv_dim);
auto req =
this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
local_mtx_->apply(dense_b->get_local_vector(), local_x);
req.wait();

if (non_local_mtx_) {
auto exec = this->get_executor();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
one_scalar_.get(), local_x);
}
non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
one_scalar_.get(), local_x);
},
b, x);
}
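
The apply above follows a start-gather / compute-local / wait / apply-coupling pattern, so the local SpMV hides the communication started by apply_async. A minimal sketch of the same overlap, again in plain MPI rather than Ginkgo's API; spmv_local and spmv_non_local are hypothetical stand-ins for the local_mtx_ and non_local_mtx_ applies.

#include <mpi.h>
#include <vector>

// Hypothetical helpers standing in for local_mtx_->apply and
// non_local_mtx_->apply (the latter accumulates into x_local).
void spmv_local(const std::vector<double>& b_local, std::vector<double>& x_local);
void spmv_non_local(const std::vector<double>& b_remote, std::vector<double>& x_local);

void overlapped_spmv(MPI_Comm neighbor_comm,
                     const std::vector<double>& send_buf,
                     const std::vector<int>& send_counts,
                     const std::vector<int>& send_displs,
                     std::vector<double>& recv_buf,
                     const std::vector<int>& recv_counts,
                     const std::vector<int>& recv_displs,
                     const std::vector<double>& b_local,
                     std::vector<double>& x_local)
{
    // Start the non-blocking gather of the remote b entries.
    MPI_Request req;
    MPI_Ineighbor_alltoallv(send_buf.data(), send_counts.data(),
                            send_displs.data(), MPI_DOUBLE, recv_buf.data(),
                            recv_counts.data(), recv_displs.data(), MPI_DOUBLE,
                            neighbor_comm, &req);
    // The local block of y = A * b overlaps with the message transfer.
    spmv_local(b_local, x_local);
    // The remote entries are only needed for the coupling block.
    MPI_Wait(&req, MPI_STATUS_IGNORE);
    spmv_non_local(recv_buf, x_local);
}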
@@ -406,21 +312,25 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());

auto exec = this->get_executor();
auto comm = this->get_communicator();
auto req = this->communicate(dense_b->get_local_vector());
auto recv_dim =
dim<2>{static_cast<size_type>(
row_gatherer_->get_collective_communicator()
->get_recv_size()),
dense_b->get_size()[1]};
auto recv_exec = mpi::requires_host_buffer(exec, comm)
? exec->get_master()
: exec;
recv_buffer_.init(recv_exec, recv_dim);
auto req =
this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
local_mtx_->apply(local_alpha, dense_b->get_local_vector(),
local_beta, local_x);
req.wait();

if (non_local_mtx_) {
auto exec = this->get_executor();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
one_scalar_.get(), local_x);
}
non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
one_scalar_.get(), local_x);
},
alpha, b, beta, x);
}
@@ -431,6 +341,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
: EnableDistributedLinOp<Matrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
DistributedBase{other.get_communicator()},
row_gatherer_{RowGatherer<LocalIndexType>::create(
other.get_executor(), other.get_communicator())},
imap_{other.get_executor()}
{
*this = other;
@@ -443,6 +355,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
: EnableDistributedLinOp<Matrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
DistributedBase{other.get_communicator()},
row_gatherer_{RowGatherer<LocalIndexType>::create(
other.get_executor(), other.get_communicator())},
imap_{other.get_executor()}
{
*this = std::move(other);
Expand All @@ -460,11 +374,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(
this->set_size(other.get_size());
local_mtx_->copy_from(other.local_mtx_);
non_local_mtx_->copy_from(other.non_local_mtx_);
gather_idxs_ = other.gather_idxs_;
send_offsets_ = other.send_offsets_;
recv_offsets_ = other.recv_offsets_;
send_sizes_ = other.send_sizes_;
recv_sizes_ = other.recv_sizes_;
row_gatherer_->copy_from(other.row_gatherer_);
imap_ = other.imap_;
one_scalar_.init(this->get_executor(), dim<2>{1, 1});
one_scalar_->fill(one<value_type>());
@@ -484,11 +394,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
other.set_size({});
local_mtx_->move_from(other.local_mtx_);
non_local_mtx_->move_from(other.non_local_mtx_);
gather_idxs_ = std::move(other.gather_idxs_);
send_offsets_ = std::move(other.send_offsets_);
recv_offsets_ = std::move(other.recv_offsets_);
send_sizes_ = std::move(other.send_sizes_);
recv_sizes_ = std::move(other.recv_sizes_);
row_gatherer_->move_from(other.row_gatherer_);
imap_ = std::move(other.imap_);
one_scalar_.init(this->get_executor(), dim<2>{1, 1});
one_scalar_->fill(one<value_type>());
26 changes: 8 additions & 18 deletions include/ginkgo/core/distributed/matrix.hpp
@@ -17,6 +17,7 @@
#include <ginkgo/core/distributed/base.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/lin_op.hpp>
#include <ginkgo/core/distributed/row_gatherer.hpp>


namespace gko {
@@ -358,6 +359,12 @@ class Matrix
return non_local_mtx_;
}

std::shared_ptr<const RowGatherer<local_index_type>> get_row_gatherer()
const
{
return row_gatherer_;
}

const index_map<local_index_type, global_index_type>& get_index_map() const
{
return imap_;
@@ -554,32 +561,15 @@ class Matrix
mpi::communicator comm, dim<2> size,
std::shared_ptr<LinOp> local_linop);

/**
* Starts a non-blocking communication of the values of b that are shared
* with other processors.
*
* @param local_b The full local vector to be communicated. The subset of
* shared values is automatically extracted.
* @return MPI request for the non-blocking communication.
*/
mpi::request communicate(const local_vector_type* local_b) const;

void apply_impl(const LinOp* b, LinOp* x) const override;

void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
LinOp* x) const override;

private:
std::vector<comm_index_type> send_offsets_;
std::vector<comm_index_type> send_sizes_;
std::vector<comm_index_type> recv_offsets_;
std::vector<comm_index_type> recv_sizes_;
array<local_index_type> gather_idxs_;
std::shared_ptr<RowGatherer<LocalIndexType>> row_gatherer_;
index_map<local_index_type, global_index_type> imap_;
gko::detail::DenseCache<value_type> one_scalar_;
gko::detail::DenseCache<value_type> host_send_buffer_;
gko::detail::DenseCache<value_type> host_recv_buffer_;
gko::detail::DenseCache<value_type> send_buffer_;
gko::detail::DenseCache<value_type> recv_buffer_;
std::shared_ptr<LinOp> local_mtx_;
std::shared_ptr<LinOp> non_local_mtx_;
