diff --git a/core/distributed/helpers.hpp b/core/distributed/helpers.hpp index 32e7f9510eb..9edf8282ed9 100644 --- a/core/distributed/helpers.hpp +++ b/core/distributed/helpers.hpp @@ -10,10 +10,14 @@ #include +#include #include #include +#include "core/base/dispatch_helper.hpp" + + namespace gko { namespace detail { @@ -140,6 +144,32 @@ void vector_dispatch(T* linop, F&& f, Args&&... args) } +#if GINKGO_BUILD_MPI + + +/** + * Specialization of run for distributed matrices. + */ +template +auto run_matrix(T* linop, F&& f, Args&&... args) +{ + using namespace gko::experimental::distributed; + return run, Matrix, + Matrix, Matrix, + Matrix, Matrix, + Matrix, int32, int32>, + Matrix, int32, int64>, + Matrix, int64, int64>, + Matrix, int32, int32>, + Matrix, int32, int64>, + Matrix, int64, int64>>( + linop, std::forward(f), std::forward(args)...); +} + + +#endif + + /** * Helper to extract a submatrix. * diff --git a/core/solver/multigrid.cpp b/core/solver/multigrid.cpp index 2f0944c0030..80c870fcc49 100644 --- a/core/solver/multigrid.cpp +++ b/core/solver/multigrid.cpp @@ -317,9 +317,14 @@ void MultigridState::generate(const LinOp* system_matrix_in, auto current_comm = distributed_fine->get_communicator(); auto next_comm = distributed_coarse->get_communicator(); auto current_local_nrows = - distributed_fine->get_local_size()[0]; + ::gko::detail::run_matrix(fine, [](auto* fine_mat) { + return fine_mat->get_local_matrix()->get_size()[0]; + }); auto next_local_nrows = - distributed_coarse->get_local_size()[0]; + ::gko::detail::run_matrix(coarse, [](auto* coarse_mat) { + return coarse_mat->get_non_local_matrix() + ->get_size()[0]; + }); this->allocate_memory( i, cycle, current_comm, next_comm, current_nrows, next_nrows, current_local_nrows, next_local_nrows);