From 24c6f7dbca325d5fbda32a521029c688cc08affc Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Mon, 12 Aug 2024 16:43:26 +0200 Subject: [PATCH] Move to deferred factory, fix issues --- core/distributed/preconditioner/schwarz.cpp | 55 +++++++++++-------- .../distributed/preconditioner/schwarz.cpp | 36 +++++------- .../distributed-solver/distributed-solver.cpp | 11 ++-- .../distributed/preconditioner/schwarz.hpp | 11 ++-- test/mpi/preconditioner/schwarz.cpp | 28 ++++++++++ 5 files changed, 84 insertions(+), 57 deletions(-) diff --git a/core/distributed/preconditioner/schwarz.cpp b/core/distributed/preconditioner/schwarz.cpp index abde203895c..4bd960f5b22 100644 --- a/core/distributed/preconditioner/schwarz.cpp +++ b/core/distributed/preconditioner/schwarz.cpp @@ -53,8 +53,11 @@ void Schwarz::apply_dense_impl( } if (this->coarse_solver_ != nullptr && this->galerkin_ops_ != nullptr) { - auto restrict = this->galerkin_ops_->get_restrict_op(); - auto prolong = this->galerkin_ops_->get_prolong_op(); + auto restrict = as(this->galerkin_ops_) + ->get_restrict_op(); + auto prolong = as(this->galerkin_ops_) + ->get_prolong_op(); + GKO_ASSERT(this->half_ != nullptr); restrict->apply(dense_b, this->csol_); this->coarse_solver_->apply(this->csol_, this->csol_); @@ -119,28 +122,32 @@ void Schwarz::generate( } - if (parameters_.galerkin_ops_factory && parameters_.coarse_solver_factory) { - this->galerkin_ops_ = as( - share(parameters_.galerkin_ops_factory->generate(dist_mat))); - auto coarse = - as>( - this->galerkin_ops_->get_coarse_op()); - auto exec = coarse->get_executor(); - auto comm = coarse->get_communicator(); - this->coarse_solver_ = - parameters_.coarse_solver_factory->generate(coarse); - // TODO: Set correct rhs and stride. - auto cs_ncols = 1; // dense_x->get_size()[1]; - auto cs_local_nrows = coarse->get_local_matrix()->get_size()[0]; - auto cs_global_nrows = coarse->get_size()[0]; - auto cs_local_size = dim<2>(cs_local_nrows, cs_ncols); - auto cs_global_size = dim<2>(cs_global_nrows, cs_ncols); - this->csol_ = gko::share(dist_vec::create(exec, comm, cs_global_size, - cs_local_size, - 1 /*dense_x->get_stride()*/)); - // this->temp_ = this->csol->clone(); - this->half_ = gko::share(gko::initialize({0.5}, exec)); + if (parameters_.galerkin_ops && parameters_.coarse_solver) { + this->galerkin_ops_ = + share(parameters_.galerkin_ops->generate(system_matrix)); + if (as(this->galerkin_ops_) + ->get_coarse_op()) { + auto coarse = + as>( + as(this->galerkin_ops_) + ->get_coarse_op()); + auto exec = coarse->get_executor(); + auto comm = coarse->get_communicator(); + this->coarse_solver_ = + share(parameters_.coarse_solver->generate(coarse)); + // TODO: Set correct rhs and stride. + auto cs_ncols = 1; // dense_x->get_size()[1]; + auto cs_local_nrows = coarse->get_local_matrix()->get_size()[0]; + auto cs_global_nrows = coarse->get_size()[0]; + auto cs_local_size = dim<2>(cs_local_nrows, cs_ncols); + auto cs_global_size = dim<2>(cs_global_nrows, cs_ncols); + this->csol_ = gko::share( + dist_vec::create(exec, comm, cs_global_size, cs_local_size, + 1 /*dense_x->get_stride()*/)); + // this->temp_ = this->csol->clone(); + this->half_ = gko::share(gko::initialize({0.5}, exec)); + } } } diff --git a/core/test/mpi/distributed/preconditioner/schwarz.cpp b/core/test/mpi/distributed/preconditioner/schwarz.cpp index 4668a456f4a..921185f7224 100644 --- a/core/test/mpi/distributed/preconditioner/schwarz.cpp +++ b/core/test/mpi/distributed/preconditioner/schwarz.cpp @@ -14,9 +14,6 @@ #include "core/test/utils.hpp" -namespace { - - template class SchwarzFactory : public ::testing::Test { protected: @@ -44,8 +41,8 @@ class SchwarzFactory : public ::testing::Test { { schwarz = Schwarz::build() .with_local_solver(jacobi_factory) - .with_galerkin_ops_factory(pgm_factory) - .with_coarse_solver_factory(pgm_factory) + .with_galerkin_ops(pgm_factory) + .with_coarse_solver(cg_factory) .on(exec) ->generate(mtx); } @@ -63,10 +60,10 @@ class SchwarzFactory : public ::testing::Test { ASSERT_EQ(a->get_size(), b->get_size()); ASSERT_EQ(a->get_parameters().local_solver, b->get_parameters().local_solver); - ASSERT_EQ(a->get_parameters().galerkin_ops_factory, - b->get_parameters().galerkin_ops_factory); - ASSERT_EQ(a->get_parameters().coarse_solver_factory, - b->get_parameters().coarse_solver_factory); + ASSERT_EQ(a->get_parameters().galerkin_ops, + b->get_parameters().galerkin_ops); + ASSERT_EQ(a->get_parameters().coarse_solver, + b->get_parameters().coarse_solver); } std::shared_ptr exec; @@ -96,15 +93,13 @@ TYPED_TEST(SchwarzFactory, CanSetLocalFactory) TYPED_TEST(SchwarzFactory, CanSetGalerkinOpsFactory) { - ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops_factory, - this->pgm_factory); + ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops, this->pgm_factory); } TYPED_TEST(SchwarzFactory, CanSetCoarseSolverFactory) { - ASSERT_EQ(this->schwarz->get_parameters().coarse_solver_factory, - this->cg_factory); + ASSERT_EQ(this->schwarz->get_parameters().coarse_solver, this->cg_factory); } @@ -128,8 +123,8 @@ TYPED_TEST(SchwarzFactory, CanBeCopied) auto cg = gko::share(Cg::build().on(this->exec)); auto copy = Schwarz::build() .with_local_solver(bj) - .with_galerkin_ops_factory(pgm) - .with_coarse_solver_factory(cg) + .with_galerkin_ops(pgm) + .with_coarse_solver(cg) .on(this->exec) ->generate(Mtx::create(this->exec, MPI_COMM_WORLD)); @@ -152,8 +147,8 @@ TYPED_TEST(SchwarzFactory, CanBeMoved) auto cg = gko::share(Cg::build().on(this->exec)); auto copy = Schwarz::build() .with_local_solver(bj) - .with_galerkin_ops_factory(pgm) - .with_coarse_solver_factory(cg) + .with_galerkin_ops(pgm) + .with_coarse_solver(cg) .on(this->exec) ->generate(Mtx::create(this->exec, MPI_COMM_WORLD)); @@ -169,8 +164,8 @@ TYPED_TEST(SchwarzFactory, CanBeCleared) ASSERT_EQ(this->schwarz->get_size(), gko::dim<2>(0, 0)); ASSERT_EQ(this->schwarz->get_parameters().local_solver, nullptr); - ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops_factory, nullptr); - ASSERT_EQ(this->schwarz->get_parameters().coarse_solver_factory, nullptr); + ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops, nullptr); + ASSERT_EQ(this->schwarz->get_parameters().coarse_solver, nullptr); } @@ -185,6 +180,3 @@ TYPED_TEST(SchwarzFactory, PassExplicitFactory) ASSERT_EQ(factory->get_parameters().local_solver, jacobi_factory); } - - -} // namespace diff --git a/examples/distributed-solver/distributed-solver.cpp b/examples/distributed-solver/distributed-solver.cpp index acda7a52f8c..427c8a97f4a 100644 --- a/examples/distributed-solver/distributed-solver.cpp +++ b/examples/distributed-solver/distributed-solver.cpp @@ -209,12 +209,11 @@ int main(int argc, char* argv[]) if (schw_type == "multi-level") { Ainv = solver::build() - .with_preconditioner( - schwarz::build() - .with_local_solver(local_solver) - .with_galerkin_ops_factory(pgm_fac) - .with_coarse_solver_factory(coarse_solver) - .on(exec)) + .with_preconditioner(schwarz::build() + .with_local_solver(local_solver) + .with_galerkin_ops(pgm_fac) + .with_coarse_solver(coarse_solver) + .on(exec)) .with_criteria( gko::stop::Iteration::build().with_max_iters(num_iters).on( exec), diff --git a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp index d5c6952dc7d..274e053d707 100644 --- a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp +++ b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp @@ -17,6 +17,7 @@ #include #include #include +#include namespace gko { @@ -80,14 +81,14 @@ class Schwarz * Operator factory to generate the triplet (prolong_op, coarse_op, * restrict_op). */ - std::shared_ptr GKO_FACTORY_PARAMETER_SCALAR( - galerkin_ops_factory, nullptr); + std::shared_ptr GKO_DEFERRED_FACTORY_PARAMETER( + galerkin_ops); /** * Coarse solver factory. */ - std::shared_ptr GKO_FACTORY_PARAMETER_SCALAR( - coarse_solver_factory, nullptr); + std::shared_ptr GKO_DEFERRED_FACTORY_PARAMETER( + coarse_solver); }; GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); @@ -141,7 +142,7 @@ class Schwarz void set_solver(std::shared_ptr new_solver); std::shared_ptr local_solver_; - std::shared_ptr galerkin_ops_; + std::shared_ptr galerkin_ops_; std::shared_ptr coarse_solver_; std::shared_ptr csol_; std::shared_ptr half_; diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index 6717cd9d888..bd66197bfb7 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -59,6 +60,9 @@ class SchwarzPreconditioner : public CommonMpiTestFixture { using solver_type = gko::solver::Bicgstab; using local_prec_type = gko::preconditioner::Jacobi; + using coarse_solver_type = + gko::preconditioner::Jacobi; + using galerkin_ops_type = gko::multigrid::Pgm; using local_matrix_type = gko::matrix::Csr; using non_dist_matrix_type = gko::matrix::Csr; @@ -125,6 +129,8 @@ class SchwarzPreconditioner : public CommonMpiTestFixture { std::shared_ptr non_dist_solver_factory; std::shared_ptr dist_solver_factory; std::shared_ptr local_solver_factory; + std::shared_ptr pgm_factory; + std::shared_ptr coarse_solver_factory; void assert_equal_to_non_distributed_vector( std::shared_ptr dist_vec, @@ -271,6 +277,28 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditioner) } +TYPED_TEST(SchwarzPreconditioner, CanApplyMultilevelPreconditioner) +{ + using value_type = typename TestFixture::value_type; + using prec = typename TestFixture::dist_prec_type; + + auto precond_factory = prec::build() + .with_local_solver(this->local_solver_factory) + .with_coarse_solver(this->coarse_solver_factory) + .with_galerkin_ops(this->pgm_factory) + .on(this->exec); + auto local_precond = + this->local_solver_factory->generate(this->non_dist_mat); + auto precond = precond_factory->generate(this->dist_mat); + + precond->apply(this->dist_b.get(), this->dist_x.get()); + local_precond->apply(this->non_dist_b.get(), this->non_dist_x.get()); + + this->assert_equal_to_non_distributed_vector(this->dist_x, + this->non_dist_x); +} + + TYPED_TEST(SchwarzPreconditioner, CanAdvancedApplyPreconditioner) { using value_type = typename TestFixture::value_type;