Skip to content

Commit

Permalink
[config] adds minimal stopping criteria specification
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed May 17, 2024
1 parent 533156b commit 8356117
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 6 deletions.
88 changes: 84 additions & 4 deletions core/config/config_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td)
{
deferred_factory_parameter<const stop::CriterionFactory> ptr;
if (config.get_tag() == pnode::tag_t::string) {
return detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
}

if (config.get_tag() == pnode::tag_t::map) {
static std::map<std::string,
std::function<deferred_factory_parameter<
gko::stop::CriterionFactory>(
Expand All @@ -62,8 +63,87 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
return criterion_map.at(config.get("type").get_string())(config,
context, td);
}
GKO_THROW_IF_INVALID(!ptr.is_empty(), "parse returned nullptr");
return ptr;

GKO_INVALID_STATE(
"Criteria must either be defined as a string or an array.");
}


std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_minimal_criteria(const pnode& config, const registry& context,
const type_descriptor& td)
{
auto map_time = [](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode time_config{{{"time_limit", config.get("time")}}};
return configure_time(time_config, context, td);
};
auto map_iteration = [](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode iter_config{{{"max_iters", config.get("iteration")}}};
return configure_iter(iter_config, context, td);
};
auto create_residual_mapping = [](const std::string& key,
const std::string& baseline,
auto configure_fn) {
return std::make_pair(
key, [=](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode res_config{{{"baseline", pnode{baseline}},
{"reduction_factor", config.get(key)}}};
return configure_fn(res_config, context, td);
});
};
std::map<
std::string,
std::function<deferred_factory_parameter<gko::stop::CriterionFactory>(
const pnode&, const registry&, type_descriptor)>>
criterion_map{
{{"time", map_time},
{"iteration", map_iteration},
create_residual_mapping("relative_residual_norm", "rhs_norm",
configure_residual),
create_residual_mapping("initial_residual_norm", "initial_resnorm",
configure_residual),
create_residual_mapping("absolute_residual_norm", "absolute",
configure_residual),
create_residual_mapping("relative_implicit_residual_norm",
"rhs_norm", configure_implicit_residual),
create_residual_mapping("initial_implicit_residual_norm",
"initial_resnorm",
configure_implicit_residual),
create_residual_mapping("absolute_implicit_residual_norm",
"absolute", configure_implicit_residual)}};

std::vector<deferred_factory_parameter<const stop::CriterionFactory>> res;
for (const auto& it : config.get_map()) {
res.emplace_back(criterion_map.at(it.first)(config, context, td));
}
return res;
}


std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_or_get_criteria(const pnode& config, const registry& context,
const type_descriptor& td)
{
if (config.get_tag() == pnode::tag_t::array) {
return parse_or_get_factory_vector<const stop::CriterionFactory>(
config, context, td);
}

if (config.get_tag() == pnode::tag_t::map) {
return parse_minimal_criteria(config, context, td);
}

if (config.get_tag() == pnode::tag_t::string) {
return {detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string())};
}

GKO_INVALID_STATE(
"Criteria must either be defined as a string, an array,"
"or an map.");
}

} // namespace config
Expand Down
9 changes: 9 additions & 0 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* parse or get an std::vector of criteria.
* A stored single criterion will be converted to an std::vector.
*/
std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_or_get_criteria(const pnode& config, const registry& context,
const type_descriptor& td);


/**
* give a vector of factory by calling parse_or_get_factory.
*/
Expand Down
3 changes: 1 addition & 2 deletions core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ typename Cg<ValueType>::parameters_type Cg<ValueType>::parse(
}
if (auto& obj = config.get("criteria")) {
params.with_criteria(
gko::config::parse_or_get_factory_vector<
const stop::CriterionFactory>(obj, context, td_for_child));
gko::config::parse_or_get_criteria(obj, context, td_for_child));
}
if (auto& obj = config.get("preconditioner")) {
params.with_preconditioner(
Expand Down

0 comments on commit 8356117

Please sign in to comment.