From 38361297209ee6984d02f2996ee9d75cd3c90d9b Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Tue, 14 May 2024 23:03:17 +0200 Subject: [PATCH] use more general default selector --- core/solver/multigrid.cpp | 16 ++++++---------- include/ginkgo/core/solver/multigrid.hpp | 5 ++--- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/core/solver/multigrid.cpp b/core/solver/multigrid.cpp index 2f0944c0030..eb3b7d9ebae 100644 --- a/core/solver/multigrid.cpp +++ b/core/solver/multigrid.cpp @@ -875,16 +875,12 @@ Multigrid::Multigrid(const Multigrid::Factory* factory, stop::combine(factory->get_parameters().criteria)}, parameters_{factory->get_parameters()} { + this->validate(); if (!parameters_.level_selector) { - if (parameters_.mg_level.size() == 1) { - level_selector_ = [](const size_type, const LinOp*) { - return size_type{0}; - }; - } else if (parameters_.mg_level.size() > 1) { - level_selector_ = [](const size_type level, const LinOp*) { - return level; - }; - } + auto num = parameters_.mg_level.size(); + level_selector_ = [num](const size_type level, const LinOp*) { + return (level < num) ? level : num - 1; + }; } else { level_selector_ = parameters_.level_selector; } @@ -898,7 +894,7 @@ Multigrid::Multigrid(const Multigrid::Factory* factory, solver_selector_ = parameters_.solver_selector; } - this->validate(); + this->set_default_initial_guess(parameters_.default_initial_guess); if (this->get_system_matrix()->get_size()[0] != 0) { // generate on the existed matrix diff --git a/include/ginkgo/core/solver/multigrid.hpp b/include/ginkgo/core/solver/multigrid.hpp index 9646e2779d7..7ca38bff661 100644 --- a/include/ginkgo/core/solver/multigrid.hpp +++ b/include/ginkgo/core/solver/multigrid.hpp @@ -222,9 +222,8 @@ class Multigrid : public EnableLinOp, * >= 3 and the number of rows of fine matrix > 1024, or the 2-idx * elements otherwise. * - * default selector: - * use the first factory when mg_level size = 1 - * use the level as the index when mg_level size > 1 + * default selector: use the level as the index when the level < + * #mg_level and reuse the last one when the level >= #mg_level */ std::function GKO_FACTORY_PARAMETER_SCALAR(level_selector, nullptr);