diff --git a/core/distributed/matrix.cpp b/core/distributed/matrix.cpp
index 8757fccb9dc..510362b469f 100644
--- a/core/distributed/matrix.cpp
+++ b/core/distributed/matrix.cpp
@@ -6,6 +6,7 @@
 #include
+#include
 #include
 #include
@@ -45,14 +46,10 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<
           Matrix<ValueType, LocalIndexType, GlobalIndexType>>{exec},
       DistributedBase{comm},
-      send_offsets_(comm.size() + 1),
-      send_sizes_(comm.size()),
-      recv_offsets_(comm.size() + 1),
-      recv_sizes_(comm.size()),
-      gather_idxs_{exec},
       one_scalar_{},
       local_mtx_{local_matrix_template->clone(exec)},
       non_local_mtx_{non_local_matrix_template->clone(exec)},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
       imap_{exec}
 {
     GKO_ASSERT(
@@ -73,13 +70,9 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<
           Matrix<ValueType, LocalIndexType, GlobalIndexType>>{exec},
       DistributedBase{comm},
-      send_offsets_(comm.size() + 1),
-      send_sizes_(comm.size()),
-      recv_offsets_(comm.size() + 1),
-      recv_sizes_(comm.size()),
-      gather_idxs_{exec},
-      one_scalar_{},
-      imap_{exec}
+      row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
+      imap_{exec},
+      one_scalar_{}
 {
     this->set_size(size);
     one_scalar_.init(exec, dim<2>{1, 1});
@@ -138,11 +131,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
         result->get_communicator().size());
     result->local_mtx_->copy_from(this->local_mtx_);
     result->non_local_mtx_->copy_from(this->non_local_mtx_);
-    result->gather_idxs_ = this->gather_idxs_;
-    result->send_offsets_ = this->send_offsets_;
-    result->recv_offsets_ = this->recv_offsets_;
-    result->recv_sizes_ = this->recv_sizes_;
-    result->send_sizes_ = this->send_sizes_;
+    result->row_gatherer_->copy_from(this->row_gatherer_);
     result->set_size(this->get_size());
 }
@@ -156,11 +145,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
         result->get_communicator().size());
     result->local_mtx_->move_from(this->local_mtx_);
     result->non_local_mtx_->move_from(this->non_local_mtx_);
-    result->gather_idxs_ = std::move(this->gather_idxs_);
-    result->send_offsets_ = std::move(this->send_offsets_);
-    result->recv_offsets_ = std::move(this->recv_offsets_);
-    result->recv_sizes_ = std::move(this->recv_sizes_);
-    result->send_sizes_ = std::move(this->send_sizes_);
+    result->row_gatherer_->move_from(this->row_gatherer_);
     result->set_size(this->get_size());
     this->set_size({});
 }
@@ -183,7 +168,6 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
     auto local_part = comm.rank();
 
     // set up LinOp sizes
-    auto num_parts = static_cast<size_type>(row_partition->get_num_parts());
     auto global_num_rows = row_partition->get_size();
     auto global_num_cols = col_partition->get_size();
     dim<2> global_dim{global_num_rows, global_num_cols};
@@ -230,44 +214,9 @@
     as<ReadableFromMatrixData<ValueType, LocalIndexType>>(this->non_local_mtx_)
         ->read(std::move(non_local_data));
 
-    // exchange step 1: determine recv_sizes, send_sizes, send_offsets
-    auto host_recv_targets = make_temporary_clone(
-        exec->get_master(), &imap_.get_remote_target_ids());
-    auto host_offsets = make_temporary_clone(
-        exec->get_master(), &imap_.get_remote_global_idxs().get_offsets());
-    std::fill(recv_sizes_.begin(), recv_sizes_.end(), 0);
-    for (size_type i = 0; i < host_recv_targets->get_size(); ++i) {
-        recv_sizes_[host_recv_targets->get_const_data()[i]] =
-            host_offsets->get_const_data()[i + 1] -
-            host_offsets->get_const_data()[i];
-    }
-    std::partial_sum(recv_sizes_.begin(), recv_sizes_.end(),
-                     recv_offsets_.begin() + 1);
-    comm.all_to_all(exec, recv_sizes_.data(), 1, send_sizes_.data(), 1);
-    std::partial_sum(send_sizes_.begin(), send_sizes_.end(),
-                     send_offsets_.begin() + 1);
-    send_offsets_[0] = 0;
-    recv_offsets_[0] = 0;
-
-    // exchange step 2: exchange gather_idxs from receivers to senders
-    auto recv_gather_idxs =
-        make_const_array_view(imap_.get_executor(), imap_.get_non_local_size(),
-                              imap_.get_remote_local_idxs().get_flat_data())
-            .copy_to_array();
-    auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-    if (use_host_buffer) {
-        recv_gather_idxs.set_executor(exec->get_master());
-        gather_idxs_.clear();
-        gather_idxs_.set_executor(exec->get_master());
-    }
-    gather_idxs_.resize_and_reset(send_offsets_.back());
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec,
-                      recv_gather_idxs.get_const_data(), recv_sizes_.data(),
-                      recv_offsets_.data(), gather_idxs_.get_data(),
-                      send_sizes_.data(), send_offsets_.data());
-    if (use_host_buffer) {
-        gather_idxs_.set_executor(exec);
-    }
+    row_gatherer_ = RowGatherer<LocalIndexType>::create(
+        exec, std::make_shared<mpi::NeighborhoodCommunicator>(comm, imap_),
+        imap_);
 }
@@ -309,53 +258,6 @@
 }
 
 
-template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
-mpi::request Matrix<ValueType, LocalIndexType, GlobalIndexType>::communicate(
-    const local_vector_type* local_b) const
-{
-    if (!non_local_mtx_) {
-        return {};
-    }
-    auto exec = this->get_executor();
-    const auto comm = this->get_communicator();
-    auto num_cols = local_b->get_size()[1];
-    auto send_size = send_offsets_.back();
-    auto recv_size = recv_offsets_.back();
-    auto send_dim = dim<2>{static_cast<size_type>(send_size), num_cols};
-    auto recv_dim = dim<2>{static_cast<size_type>(recv_size), num_cols};
-    recv_buffer_.init(exec, recv_dim);
-    send_buffer_.init(exec, send_dim);
-
-    local_b->row_gather(&gather_idxs_, send_buffer_.get());
-
-    auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-    if (use_host_buffer) {
-        host_recv_buffer_.init(exec->get_master(), recv_dim);
-        host_send_buffer_.init(exec->get_master(), send_dim);
-        host_send_buffer_->copy_from(send_buffer_.get());
-    }
-
-    mpi::contiguous_type type(num_cols, mpi::type_impl<ValueType>::get_type());
-    auto send_ptr = use_host_buffer ? host_send_buffer_->get_const_values()
-                                    : send_buffer_->get_const_values();
-    auto recv_ptr = use_host_buffer ? host_recv_buffer_->get_values()
-                                    : recv_buffer_->get_values();
-    exec->synchronize();
-#ifdef GINKGO_FORCE_SPMV_BLOCKING_COMM
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
-                      send_sizes_.data(), send_offsets_.data(), type.get(),
-                      recv_ptr, recv_sizes_.data(), recv_offsets_.data(),
-                      type.get());
-    return {};
-#else
-    return comm.i_all_to_all_v(
-        use_host_buffer ? exec->get_master() : exec, send_ptr,
-        send_sizes_.data(), send_offsets_.data(), type.get(), recv_ptr,
-        recv_sizes_.data(), recv_offsets_.data(), type.get());
-#endif
-}
-
-
 template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
 void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
     const LinOp* b, LinOp* x) const
@@ -371,20 +273,24 @@
                     dense_x->get_local_values()),
                 dense_x->get_local_vector()->get_stride());
 
+            auto exec = this->get_executor();
             auto comm = this->get_communicator();
-            auto req = this->communicate(dense_b->get_local_vector());
+            auto recv_dim =
+                dim<2>{static_cast<size_type>(
+                           row_gatherer_->get_collective_communicator()
+                               ->get_recv_size()),
+                       dense_b->get_size()[1]};
+            auto recv_exec = mpi::requires_host_buffer(exec, comm)
+                                 ? exec->get_master()
+                                 : exec;
+            recv_buffer_.init(recv_exec, recv_dim);
+            auto req =
+                this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
             local_mtx_->apply(dense_b->get_local_vector(), local_x);
             req.wait();
-            if (non_local_mtx_) {
-                auto exec = this->get_executor();
-                auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-                if (use_host_buffer) {
-                    recv_buffer_->copy_from(host_recv_buffer_.get());
-                }
-                non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
-                                      one_scalar_.get(), local_x);
-            }
+            non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
+                                  one_scalar_.get(), local_x);
         },
         b, x);
 }
@@ -406,21 +312,25 @@
                     dense_x->get_local_values()),
                 dense_x->get_local_vector()->get_stride());
 
+            auto exec = this->get_executor();
             auto comm = this->get_communicator();
-            auto req = this->communicate(dense_b->get_local_vector());
+            auto recv_dim =
+                dim<2>{static_cast<size_type>(
+                           row_gatherer_->get_collective_communicator()
+                               ->get_recv_size()),
+                       dense_b->get_size()[1]};
+            auto recv_exec = mpi::requires_host_buffer(exec, comm)
+                                 ? exec->get_master()
+                                 : exec;
+            recv_buffer_.init(recv_exec, recv_dim);
+            auto req =
+                this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
             local_mtx_->apply(local_alpha, dense_b->get_local_vector(),
                               local_beta, local_x);
             req.wait();
-            if (non_local_mtx_) {
-                auto exec = this->get_executor();
-                auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-                if (use_host_buffer) {
-                    recv_buffer_->copy_from(host_recv_buffer_.get());
-                }
-                non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
-                                      one_scalar_.get(), local_x);
-            }
+            non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
+                                  one_scalar_.get(), local_x);
         },
         alpha, b, beta, x);
 }
@@ -431,6 +341,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
     : EnableDistributedLinOp<Matrix<ValueType, LocalIndexType,
                                     GlobalIndexType>>{other.get_executor()},
       DistributedBase{other.get_communicator()},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(
+          other.get_executor(), other.get_communicator())},
       imap_{other.get_executor()}
 {
     *this = other;
@@ -443,6 +355,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<Matrix<ValueType, LocalIndexType,
                                     GlobalIndexType>>{other.get_executor()},
       DistributedBase{other.get_communicator()},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(
+          other.get_executor(), other.get_communicator())},
       imap_{other.get_executor()}
 {
     *this = std::move(other);
@@ -460,11 +374,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(
         this->set_size(other.get_size());
         local_mtx_->copy_from(other.local_mtx_);
         non_local_mtx_->copy_from(other.non_local_mtx_);
-        gather_idxs_ = other.gather_idxs_;
-        send_offsets_ = other.send_offsets_;
-        recv_offsets_ = other.recv_offsets_;
-        send_sizes_ = other.send_sizes_;
-        recv_sizes_ = other.recv_sizes_;
+        row_gatherer_->copy_from(other.row_gatherer_);
         imap_ = other.imap_;
         one_scalar_.init(this->get_executor(), dim<2>{1, 1});
         one_scalar_->fill(one<value_type>());
@@ -484,11 +394,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
         other.set_size({});
         local_mtx_->move_from(other.local_mtx_);
         non_local_mtx_->move_from(other.non_local_mtx_);
-        gather_idxs_ = std::move(other.gather_idxs_);
-        send_offsets_ = std::move(other.send_offsets_);
-        recv_offsets_ = std::move(other.recv_offsets_);
-        send_sizes_ = std::move(other.send_sizes_);
-        recv_sizes_ = std::move(other.recv_sizes_);
+        row_gatherer_->move_from(other.row_gatherer_);
         imap_ = std::move(other.imap_);
         one_scalar_.init(this->get_executor(), dim<2>{1, 1});
         one_scalar_->fill(one<value_type>());
diff --git a/include/ginkgo/core/distributed/matrix.hpp b/include/ginkgo/core/distributed/matrix.hpp
index 64cb2f6c948..751f326e4e4 100644
--- a/include/ginkgo/core/distributed/matrix.hpp
+++ b/include/ginkgo/core/distributed/matrix.hpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <ginkgo/core/distributed/row_gatherer.hpp>
 
 
 namespace gko {
@@ -358,6 +359,12 @@ class Matrix
         return non_local_mtx_;
     }
 
+    std::shared_ptr<const RowGatherer<local_index_type>> get_row_gatherer()
+        const
+    {
+        return row_gatherer_;
+    }
+
     const index_map<local_index_type, global_index_type>& get_index_map() const
     {
         return imap_;
     }
@@ -554,32 +561,15 @@
                    mpi::communicator comm, dim<2> size,
                    std::shared_ptr<LinOp> local_linop);
 
-    /**
-     * Starts a non-blocking communication of the values of b that are shared
-     * with other processors.
-     *
-     * @param local_b The full local vector to be communicated. The subset of
-     *                shared values is automatically extracted.
-     * @return MPI request for the non-blocking communication.
-     */
-    mpi::request communicate(const local_vector_type* local_b) const;
-
     void apply_impl(const LinOp* b, LinOp* x) const override;
 
     void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
                     LinOp* x) const override;
 
 private:
-    std::vector<comm_index_type> send_offsets_;
-    std::vector<comm_index_type> send_sizes_;
-    std::vector<comm_index_type> recv_offsets_;
-    std::vector<comm_index_type> recv_sizes_;
-    array<local_index_type> gather_idxs_;
+    std::shared_ptr<RowGatherer<local_index_type>> row_gatherer_;
     index_map<local_index_type, global_index_type> imap_;
    gko::detail::DenseCache<value_type> one_scalar_;
-    gko::detail::DenseCache<value_type> host_send_buffer_;
-    gko::detail::DenseCache<value_type> host_recv_buffer_;
-    gko::detail::DenseCache<value_type> send_buffer_;
     gko::detail::DenseCache<value_type> recv_buffer_;
     std::shared_ptr<LinOp> local_mtx_;
     std::shared_ptr<LinOp> non_local_mtx_;
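
For context, the caller-side behavior after this change can be sketched as follows. This is a minimal illustration, not part of the diff: it assumes an MPI build of Ginkgo, uses only public API touched or shown here (`read_distributed` now builds the `RowGatherer` internally, and `apply` overlaps the gather with the local SpMV), and the problem setup (a 1D Laplacian on 8 global rows, the names `A`, `b`, `x`, `part`) is purely illustrative.

```cpp
#include <ginkgo/ginkgo.hpp>

int main(int argc, char* argv[])
{
    // Hypothetical driver: distribute a small 1D Laplacian and apply it.
    gko::experimental::mpi::environment env(argc, argv);
    auto exec = gko::ReferenceExecutor::create();
    auto comm = gko::experimental::mpi::communicator(MPI_COMM_WORLD);

    using Mtx =
        gko::experimental::distributed::Matrix<double, gko::int32, gko::int64>;
    using Vec = gko::experimental::distributed::Vector<double>;
    using Part =
        gko::experimental::distributed::Partition<gko::int32, gko::int64>;

    const gko::int64 n = 8;
    auto part = gko::share(
        Part::build_from_global_size_uniform(exec, comm.size(), n));

    // Assemble the global tridiagonal stencil; read_distributed keeps the
    // locally owned rows and, after this PR, also sets up imap_ and
    // row_gatherer_ instead of the raw send/recv size and offset vectors.
    gko::matrix_data<double, gko::int64> data{gko::dim<2>{n, n}};
    for (gko::int64 i = 0; i < n; ++i) {
        data.nonzeros.emplace_back(i, i, 2.0);
        if (i > 0) data.nonzeros.emplace_back(i, i - 1, -1.0);
        if (i < n - 1) data.nonzeros.emplace_back(i, i + 1, -1.0);
    }
    auto A = Mtx::create(exec, comm);
    A->read_distributed(data, part);

    auto b = Vec::create(exec, comm);
    b->read_distributed(
        gko::matrix_data<double, gko::int64>{gko::dim<2>{n, 1}, 1.0}, part);
    auto x = gko::clone(b);
    x->fill(0.0);

    // Internally: row_gatherer_->apply_async starts the non-blocking gather
    // of the remote entries of b into recv_buffer_, the local SpMV runs while
    // communication is in flight, and the non-local SpMV runs after
    // req.wait(), exactly as in the rewritten apply_impl above.
    A->apply(b, x);
}
```

The refactor is visible only indirectly in the last call: the all-to-all bookkeeping each `Matrix` used to carry (`send_sizes_`, `send_offsets_`, `gather_idxs_`, the host/send buffer caches) now lives behind `RowGatherer::apply_async`, so the same communication/computation overlap is available to any code that holds the object returned by `get_row_gatherer()`.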