diff --git a/core/distributed/matrix.cpp b/core/distributed/matrix.cpp index 1dcddbd1a6a..1da4fc1604c 100644 --- a/core/distributed/matrix.cpp +++ b/core/distributed/matrix.cpp @@ -238,6 +238,40 @@ void Matrix::move_to( } +template +void Matrix::read_distributed( + const device_matrix_data& local_data, + const device_matrix_data& non_local_data, + const std::vector& send_offsets, + const std::vector& send_sizes, + const std::vector& recv_offsets, + const std::vector& recv_sizes, + const array& gather_idxs) +{ + auto exec = this->get_executor(); + const auto comm = this->get_communicator(); + + GKO_ASSERT_EQUAL_ROWS(local_data.get_size(), non_local_data.get_size()); + + as>(local_mtx_) + ->read(std::move(local_data)); + as>(non_local_mtx_) + ->read(std::move(non_local_data)); + + auto num_rows = local_mtx_->get_size()[0]; + auto num_cols = local_mtx_->get_size()[1]; + comm.all_reduce(exec, &num_rows, 1, MPI_SUM); + comm.all_reduce(exec, &num_cols, 1, MPI_SUM); + this->set_size({num_rows, num_cols}); + + send_offsets_ = send_offsets; + send_sizes_ = send_sizes; + recv_offsets_ = recv_offsets; + recv_sizes_ = recv_sizes; + gather_idxs_ = gather_idxs; +} + + template void Matrix::read_distributed( const device_matrix_data& data, diff --git a/include/ginkgo/core/distributed/matrix.hpp b/include/ginkgo/core/distributed/matrix.hpp index 4689c3d3381..d7bca9b91ac 100644 --- a/include/ginkgo/core/distributed/matrix.hpp +++ b/include/ginkgo/core/distributed/matrix.hpp @@ -298,6 +298,29 @@ class Matrix std::shared_ptr> partition); + /** + * Reads a matrix that is split into local data that only operates on local + * DOFs and non_local_data that needs input from non-local DOFs. + * + * local_data and non_local_data must have the same number of rows. + * Additionally, it is assumed that no column of non_local_data is + * completely zero. The number of columns in non_local_data must the same as + * the number of receiving indices in the communication pattern. + * + * @param local_data The matrix data for the local block. + * @param non_local_data The matrix data for the non-local block. + * @param sparse_comm The communication pattern + */ + void read_distributed( + const device_matrix_data& local_data, + const device_matrix_data& non_local_data, + const std::vector& send_offsets, + const std::vector& send_sizes, + const std::vector& recv_offsets, + const std::vector& recv_sizes, + const array& gather_idxs); + + /** * Reads a square matrix from the matrix_data structure and a global * partition. diff --git a/test/mpi/matrix.cpp b/test/mpi/matrix.cpp index d836eb008d9..7dda19ee624 100644 --- a/test/mpi/matrix.cpp +++ b/test/mpi/matrix.cpp @@ -156,6 +156,54 @@ TYPED_TEST(MatrixCreation, ReadsDistributedWithColPartition) } +TYPED_TEST(MatrixCreation, ReadsDistributedLocalAndNonLocalData) +{ + // 0-2 2-4 4-5 + using value_type = typename TestFixture::value_type; + using local_index_type = typename TestFixture::local_index_type; + using comm_index_type = gko::experimental::distributed::comm_index_type; + using csr = typename TestFixture::local_matrix_type; + using dist_mtx_type = typename TestFixture::dist_mtx_type; + + this->dist_mat->read_distributed(this->mat_input, this->row_part); + + auto exec = this->dist_mat->get_executor(); + auto comm = this->dist_mat->get_communicator(); + + gko::matrix_data local_host_matrix; + gko::as(this->dist_mat->get_local_matrix())->write(local_host_matrix); + auto local_matrix = + gko::device_matrix_data::create_from_host( + exec, local_host_matrix); + + gko::matrix_data non_local_host_matrix; + gko::as(this->dist_mat->get_non_local_matrix()) + ->write(non_local_host_matrix); + auto non_local_matrix = + gko::device_matrix_data::create_from_host( + exec, non_local_host_matrix); + + // offsets and sizes and gather_idxs are only used during apply + // thus we use dummies here + std::vector send_offsets{}; + std::vector recv_offsets{}; + std::vector send_sizes{}; + std::vector recv_sizes{}; + gko::array gather_idxs(exec); + + auto dist_mat_b = dist_mtx_type::create(exec, comm); + + dist_mat_b->read_distributed(local_matrix, non_local_matrix, send_offsets, + send_sizes, recv_offsets, recv_sizes, + gather_idxs); + + GKO_ASSERT_MTX_NEAR(gko::as(this->dist_mat->get_local_matrix()), + gko::as(dist_mat_b->get_local_matrix()), 0); + GKO_ASSERT_MTX_NEAR(gko::as(this->dist_mat->get_non_local_matrix()), + gko::as(dist_mat_b->get_local_matrix()), 0); +} + + TYPED_TEST(MatrixCreation, BuildOnlyLocal) { using value_type = typename TestFixture::value_type;