Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solver config #1395

Merged
merged 8 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ target_sources(ginkgo
config/config_helper.cpp
config/property_tree.cpp
config/registry.cpp
config/solver_config.cpp
config/stop_config.cpp
config/type_descriptor.cpp
distributed/index_map.cpp
Expand Down
57 changes: 55 additions & 2 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,21 @@ namespace config {
* LinOpFactoryType enum is to avoid forward declaration, linopfactory header,
* two template versions of parse
*/
enum class LinOpFactoryType : int { Cg = 0 };
enum class LinOpFactoryType : int {
Cg = 0,
Bicg,
Bicgstab,
Fcg,
Cgs,
Ir,
Idr,
Gcr,
Gmres,
CbGmres,
Direct,
LowerTrs,
UpperTrs
};


/**
Expand Down Expand Up @@ -107,13 +121,29 @@ inline std::vector<deferred_factory_parameter<T>> parse_or_get_factory_vector(
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for bool type
*/
template <typename ValueType>
inline std::enable_if_t<std::is_same<ValueType, bool>::value, bool> get_value(
const pnode& config)
{
auto val = config.get_boolean();
return val;
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for integral type
*/
template <typename IndexType>
inline std::enable_if_t<std::is_integral<IndexType>::value, IndexType>
inline std::enable_if_t<std::is_integral<IndexType>::value &&
!std::is_same<IndexType, bool>::value,
IndexType>
get_value(const pnode& config)
{
auto val = config.get_integer();
Expand Down Expand Up @@ -173,6 +203,29 @@ get_value(const pnode& config)
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for initial_guess_mode
*/
template <typename ValueType>
inline std::enable_if_t<
std::is_same<ValueType, solver::initial_guess_mode>::value,
solver::initial_guess_mode>
get_value(const pnode& config)
{
auto val = config.get_string();
if (val == "zero") {
return solver::initial_guess_mode::zero;
} else if (val == "rhs") {
return solver::initial_guess_mode::rhs;
} else if (val == "provided") {
return solver::initial_guess_mode::provided;
}
GKO_INVALID_STATE("Wrong value for initial_guess_mode");
}


} // namespace config
} // namespace gko

Expand Down
2 changes: 2 additions & 0 deletions core/config/dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
Expand Down Expand Up @@ -106,6 +107,7 @@ deferred_factory_parameter<ReturnType> dispatch(
using value_type_list =
syn::type_list<double, float, std::complex<double>, std::complex<float>>;

using index_type_list = syn::type_list<int32, int64>;

} // namespace config
} // namespace gko
Expand Down
61 changes: 61 additions & 0 deletions core/config/parse_macro.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_PARSE_MACRO_HPP_
#define GKO_CORE_CONFIG_PARSE_MACRO_HPP_


#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/config/type_descriptor.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/type_descriptor_helper.hpp"


// for value_type only
#define GKO_PARSE_VALUE_TYPE(_type, _configurator) \
template <> \
deferred_factory_parameter<gko::LinOpFactory> \
parse<gko::config::LinOpFactoryType::_type>( \
const gko::config::pnode& config, \
const gko::config::registry& context, \
const gko::config::type_descriptor& td) \
{ \
auto updated = gko::config::update_type(config, td); \
return gko::config::dispatch<gko::LinOpFactory, _configurator>( \
config, context, updated, \
gko::config::make_type_selector(updated.get_value_typestr(), \
gko::config::value_type_list())); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


// for value_type and index_type
#define GKO_PARSE_VALUE_AND_INDEX_TYPE(_type, _configurator) \
template <> \
deferred_factory_parameter<gko::LinOpFactory> \
parse<gko::config::LinOpFactoryType::_type>( \
const gko::config::pnode& config, \
const gko::config::registry& context, \
const gko::config::type_descriptor& td) \
{ \
auto updated = gko::config::update_type(config, td); \
return gko::config::dispatch<gko::LinOpFactory, _configurator>( \
config, context, updated, \
gko::config::make_type_selector(updated.get_value_typestr(), \
gko::config::value_type_list()), \
gko::config::make_type_selector(updated.get_index_typestr(), \
gko::config::index_type_list())); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


#endif // GKO_CORE_CONFIG_PARSE_MACRO_HPP_
14 changes: 13 additions & 1 deletion core/config/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,19 @@ namespace config {

configuration_map generate_config_map()
{
return {{"solver::Cg", parse<LinOpFactoryType::Cg>}};
return {{"solver::Cg", parse<LinOpFactoryType::Cg>},
{"solver::Bicg", parse<LinOpFactoryType::Bicg>},
{"solver::Bicgstab", parse<LinOpFactoryType::Bicgstab>},
{"solver::Fcg", parse<LinOpFactoryType::Fcg>},
{"solver::Cgs", parse<LinOpFactoryType::Cgs>},
{"solver::Ir", parse<LinOpFactoryType::Ir>},
{"solver::Idr", parse<LinOpFactoryType::Idr>},
{"solver::Gcr", parse<LinOpFactoryType::Gcr>},
{"solver::Gmres", parse<LinOpFactoryType::Gmres>},
{"solver::CbGmres", parse<LinOpFactoryType::CbGmres>},
{"solver::Direct", parse<LinOpFactoryType::Direct>},
{"solver::LowerTrs", parse<LinOpFactoryType::LowerTrs>},
{"solver::UpperTrs", parse<LinOpFactoryType::UpperTrs>}};
}


Expand Down
48 changes: 48 additions & 0 deletions core/config/solver_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/bicg.hpp>
#include <ginkgo/core/solver/bicgstab.hpp>
#include <ginkgo/core/solver/cb_gmres.hpp>
#include <ginkgo/core/solver/cg.hpp>
#include <ginkgo/core/solver/cgs.hpp>
#include <ginkgo/core/solver/direct.hpp>
#include <ginkgo/core/solver/fcg.hpp>
#include <ginkgo/core/solver/gcr.hpp>
#include <ginkgo/core/solver/gmres.hpp>
#include <ginkgo/core/solver/idr.hpp>
#include <ginkgo/core/solver/ir.hpp>
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/parse_macro.hpp"
#include "core/config/solver_config.hpp"


namespace gko {
namespace config {


GKO_PARSE_VALUE_TYPE(Cg, gko::solver::Cg);
GKO_PARSE_VALUE_TYPE(Bicg, gko::solver::Bicg);
GKO_PARSE_VALUE_TYPE(Bicgstab, gko::solver::Bicgstab);
GKO_PARSE_VALUE_TYPE(Cgs, gko::solver::Cgs);
GKO_PARSE_VALUE_TYPE(Fcg, gko::solver::Fcg);
GKO_PARSE_VALUE_TYPE(Ir, gko::solver::Ir);
GKO_PARSE_VALUE_TYPE(Idr, gko::solver::Idr);
GKO_PARSE_VALUE_TYPE(Gcr, gko::solver::Gcr);
GKO_PARSE_VALUE_TYPE(Gmres, gko::solver::Gmres);
GKO_PARSE_VALUE_TYPE(CbGmres, gko::solver::CbGmres);
GKO_PARSE_VALUE_AND_INDEX_TYPE(Direct, gko::experimental::solver::Direct);
GKO_PARSE_VALUE_AND_INDEX_TYPE(LowerTrs, gko::solver::LowerTrs);
GKO_PARSE_VALUE_AND_INDEX_TYPE(UpperTrs, gko::solver::UpperTrs);


} // namespace config
} // namespace gko
45 changes: 45 additions & 0 deletions core/config/solver_config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_
#define GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_


#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"

namespace gko {
namespace config {


template <typename SolverParam>
inline void common_solver_parse(SolverParam& params, const pnode& config,
const registry& context,
type_descriptor td_for_child)
{
if (auto& obj = config.get("generated_preconditioner")) {
params.with_generated_preconditioner(
gko::config::get_stored_obj<const LinOp>(obj, context));
}
if (auto& obj = config.get("criteria")) {
params.with_criteria(
gko::config::parse_or_get_factory_vector<
const stop::CriterionFactory>(obj, context, td_for_child));
}
if (auto& obj = config.get("preconditioner")) {
params.with_preconditioner(
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
}


} // namespace config
} // namespace gko

#endif // GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_
12 changes: 6 additions & 6 deletions core/config/stop_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@ namespace config {
deferred_factory_parameter<stop::CriterionFactory> configure_time(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Time::build();
auto params = stop::Time::build();
if (auto& obj = config.get("time_limit")) {
factory.with_time_limit(gko::config::get_value<long long int>(obj));
params.with_time_limit(gko::config::get_value<long long int>(obj));
}
return factory;
return params;
}


deferred_factory_parameter<stop::CriterionFactory> configure_iter(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Iteration::build();
auto params = stop::Iteration::build();
if (auto& obj = config.get("max_iters")) {
factory.with_max_iters(gko::config::get_value<size_type>(obj));
params.with_max_iters(gko::config::get_value<size_type>(obj));
}
return factory;
return params;
}


Expand Down
49 changes: 49 additions & 0 deletions core/config/trisolver_config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_
#define GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_


#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"

namespace gko {
namespace config {


template <typename SolverParam>
inline void common_trisolver_parse(SolverParam& params, const pnode& config,
const registry& context,
type_descriptor td_for_child)
{
if (auto& obj = config.get("num_rhs")) {
params.with_num_rhs(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("unit_diagonal")) {
params.with_unit_diagonal(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("algorithm")) {
using gko::solver::trisolve_algorithm;
auto str = obj.get_string();
if (str == "sparselib") {
params.with_algorithm(trisolve_algorithm::sparselib);
} else if (str == "syncfree") {
params.with_algorithm(trisolve_algorithm::syncfree);
} else {
GKO_INVALID_STATE("Wrong value for algorithm");
}
}
}


} // namespace config
} // namespace gko

#endif // GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_
12 changes: 12 additions & 0 deletions core/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ginkgo/core/base/precision_dispatch.hpp>


#include "core/config/solver_config.hpp"
#include "core/solver/bicg_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"

Expand All @@ -32,6 +33,17 @@ GKO_REGISTER_OPERATION(step_2, bicg::step_2);
} // namespace bicg


template <typename ValueType>
typename Bicg<ValueType>::parameters_type Bicg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto params = solver::Bicg<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
return params;
}


template <typename ValueType>
std::unique_ptr<LinOp> Bicg<ValueType>::transpose() const
{
Expand Down
Loading
Loading