diff --git a/core/config/config_helper.cpp b/core/config/config_helper.cpp index 2c70fa6f79e..8239f7da491 100644 --- a/core/config/config_helper.cpp +++ b/core/config/config_helper.cpp @@ -45,11 +45,12 @@ parse_or_get_factory(const pnode& config, const registry& context, const type_descriptor& td) { - deferred_factory_parameter ptr; if (config.get_tag() == pnode::tag_t::string) { return detail::registry_accessor::get_data( context, config.get_string()); - } else if (config.get_tag() == pnode::tag_t::map) { + } + + if (config.get_tag() == pnode::tag_t::map) { static std::map( @@ -62,8 +63,87 @@ parse_or_get_factory(const pnode& config, return criterion_map.at(config.get("type").get_string())(config, context, td); } - GKO_THROW_IF_INVALID(!ptr.is_empty(), "parse returned nullptr"); - return ptr; + + GKO_INVALID_STATE( + "Criteria must either be defined as a string or an array."); +} + + +std::vector> +parse_minimal_criteria(const pnode& config, const registry& context, + const type_descriptor& td) +{ + auto map_time = [](const pnode& config, const registry& context, + const type_descriptor& td) { + pnode time_config{{{"time_limit", config.get("time")}}}; + return configure_time(time_config, context, td); + }; + auto map_iteration = [](const pnode& config, const registry& context, + const type_descriptor& td) { + pnode iter_config{{{"max_iters", config.get("iteration")}}}; + return configure_iter(iter_config, context, td); + }; + auto create_residual_mapping = [](const std::string& key, + const std::string& baseline, + auto configure_fn) { + return std::make_pair( + key, [=](const pnode& config, const registry& context, + const type_descriptor& td) { + pnode res_config{{{"baseline", pnode{baseline}}, + {"reduction_factor", config.get(key)}}}; + return configure_fn(res_config, context, td); + }); + }; + std::map< + std::string, + std::function( + const pnode&, const registry&, type_descriptor)>> + criterion_map{ + {{"time", map_time}, + {"iteration", map_iteration}, + create_residual_mapping("relative_residual_norm", "rhs_norm", + configure_residual), + create_residual_mapping("initial_residual_norm", "initial_resnorm", + configure_residual), + create_residual_mapping("absolute_residual_norm", "absolute", + configure_residual), + create_residual_mapping("relative_implicit_residual_norm", + "rhs_norm", configure_implicit_residual), + create_residual_mapping("initial_implicit_residual_norm", + "initial_resnorm", + configure_implicit_residual), + create_residual_mapping("absolute_implicit_residual_norm", + "absolute", configure_implicit_residual)}}; + + std::vector> res; + for (const auto& it : config.get_map()) { + res.emplace_back(criterion_map.at(it.first)(config, context, td)); + } + return res; +} + + +std::vector> +parse_or_get_criteria(const pnode& config, const registry& context, + const type_descriptor& td) +{ + if (config.get_tag() == pnode::tag_t::array) { + return parse_or_get_factory_vector( + config, context, td); + } + + if (config.get_tag() == pnode::tag_t::map) { + return parse_minimal_criteria(config, context, td); + } + + if (config.get_tag() == pnode::tag_t::string) { + return {detail::registry_accessor::get_data( + context, config.get_string())}; + } + + GKO_INVALID_STATE( + "Criteria must either be defined as a string, an array," + "or an map."); } } // namespace config diff --git a/core/config/config_helper.hpp b/core/config/config_helper.hpp index 798d3623856..3117a0c7cd6 100644 --- a/core/config/config_helper.hpp +++ b/core/config/config_helper.hpp @@ -86,6 +86,15 @@ parse_or_get_factory(const pnode& config, const registry& context, const type_descriptor& td); +/** + * parse or get an std::vector of criteria. + * A stored single criterion will be converted to an std::vector. + */ +std::vector> +parse_or_get_criteria(const pnode& config, const registry& context, + const type_descriptor& td); + + /** * give a vector of factory by calling parse_or_get_factory. */ diff --git a/core/solver/cg.cpp b/core/solver/cg.cpp index 71e5fcfbb3b..fe5450be070 100644 --- a/core/solver/cg.cpp +++ b/core/solver/cg.cpp @@ -69,8 +69,7 @@ typename Cg::parameters_type Cg::parse( } if (auto& obj = config.get("criteria")) { params.with_criteria( - gko::config::parse_or_get_factory_vector< - const stop::CriterionFactory>(obj, context, td_for_child)); + gko::config::parse_or_get_criteria(obj, context, td_for_child)); } if (auto& obj = config.get("preconditioner")) { params.with_preconditioner(