Skip to content

Commit

Permalink
[dist-mat] add run specialization for distributed matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed May 8, 2024
1 parent 2cdef8f commit 0c725a9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
30 changes: 30 additions & 0 deletions core/distributed/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@


#include <ginkgo/config.hpp>
#include <ginkgo/core/distributed/matrix.hpp>
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/dense.hpp>


#include "core/base/dispatch_helper.hpp"


namespace gko {
namespace detail {

Expand Down Expand Up @@ -140,6 +144,32 @@ void vector_dispatch(T* linop, F&& f, Args&&... args)
}


#if GINKGO_BUILD_MPI


/**
* Specialization of run for distributed matrices.
*/
template <typename T, typename F, typename... Args>
auto run_matrix(T* linop, F&& f, Args&&... args)
{
using namespace gko::experimental::distributed;
return run<Matrix<double, int32, int32>, Matrix<double, int32, int64>,
Matrix<double, int64, int64>, Matrix<float, int32, int32>,
Matrix<float, int32, int64>, Matrix<float, int64, int64>,
Matrix<std::complex<double>, int32, int32>,
Matrix<std::complex<double>, int32, int64>,
Matrix<std::complex<double>, int64, int64>,
Matrix<std::complex<float>, int32, int32>,
Matrix<std::complex<float>, int32, int64>,
Matrix<std::complex<float>, int64, int64>>(
linop, std::forward<F>(f), std::forward<Args>(args)...);
}


#endif


/**
* Helper to extract a submatrix.
*
Expand Down
9 changes: 7 additions & 2 deletions core/solver/multigrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,14 @@ void MultigridState::generate(const LinOp* system_matrix_in,
auto current_comm = distributed_fine->get_communicator();
auto next_comm = distributed_coarse->get_communicator();
auto current_local_nrows =
distributed_fine->get_local_size()[0];
::gko::detail::run_matrix(fine, [](auto* fine_mat) {
return fine_mat->get_local_matrix()->get_size()[0];
});
auto next_local_nrows =
distributed_coarse->get_local_size()[0];
::gko::detail::run_matrix(coarse, [](auto* coarse_mat) {
return coarse_mat->get_non_local_matrix()
->get_size()[0];
});
this->allocate_memory<VectorType>(
i, cycle, current_comm, next_comm, current_nrows,
next_nrows, current_local_nrows, next_local_nrows);
Expand Down

0 comments on commit 0c725a9

Please sign in to comment.