Skip to content

Commit

Permalink
adapt the corresponding changes and remove the macro
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Mar 28, 2024
1 parent eb7fb2a commit 2d5fed4
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 126 deletions.
4 changes: 2 additions & 2 deletions core/config/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ template <typename ValueType>
inline typename std::enable_if<std::is_same<ValueType, bool>::value, bool>::type
get_value(const pnode& config)
{
auto val = config.get_data<bool>();
auto val = config.get_boolean();
return val;
}

Expand Down Expand Up @@ -157,7 +157,7 @@ inline typename std::enable_if<
solver::initial_guess_mode>::type
get_value(const pnode& config)
{
auto val = config.get_data<std::string>();
auto val = config.get_string();
if (val == "zero") {
return solver::initial_guess_mode::zero;
} else if (val == "rhs") {
Expand Down
22 changes: 15 additions & 7 deletions core/config/solver_config.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017-2023 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -22,12 +22,20 @@ inline void common_solver_configure(SolverFactory& factory, const pnode& config,
const registry& context,
type_descriptor td_for_child)
{
SET_POINTER(factory, const LinOp, generated_preconditioner, config, context,
td_for_child);
SET_FACTORY_VECTOR(factory, const stop::CriterionFactory, criteria, config,
context, td_for_child);
SET_FACTORY(factory, const LinOpFactory, preconditioner, config, context,
td_for_child);
if (auto& obj = config.get("generated_preconditioner")) {
factory.with_generated_preconditioner(
gko::config::get_pointer<const LinOp>(obj, context, td_for_child));
}
if (auto& obj = config.get("criteria")) {
factory.with_criteria(
gko::config::get_factory_vector<const stop::CriterionFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("preconditioner")) {
factory.with_preconditioner(
gko::config::get_factory<const LinOpFactory>(obj, context,
td_for_child));
}
}


Expand Down
9 changes: 5 additions & 4 deletions core/solver/cb_gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ CbGmres<ValueType>::build_from_config(const config::pnode& config,
{
auto factory = solver::CbGmres<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
if (config.contains("storage_precision")) {
if (auto& obj = config.get("krylov_dim")) {
factory.with_krylov_dim(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("storage_precision")) {
auto get_storage_precision = [](std::string str) {
using gko::solver::cb_gmres::storage_precision;
if (str == "keep") {
Expand All @@ -186,8 +188,7 @@ CbGmres<ValueType>::build_from_config(const config::pnode& config,
}
GKO_INVALID_STATE("Wrong value for storage_precision");
};
factory.with_storage_precision(get_storage_precision(
config.at("storage_precision").get_data<std::string>()));
factory.with_storage_precision(get_storage_precision(obj.get_string()));
}
return factory;
}
Expand Down
10 changes: 7 additions & 3 deletions core/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ Direct<ValueType, IndexType>::build_from_config(
config::type_descriptor td_for_child)
{
auto factory = Direct<ValueType, IndexType>::build();
SET_VALUE(factory, size_type, num_rhs, config);
SET_FACTORY(factory, const LinOpFactory, factorization, config, context,
td_for_child);
if (auto& obj = config.get("num_rhs")) {
factory.with_num_rhs(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("factorization")) {
factory.with_factorization(gko::config::get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
return factory;
}

Expand Down
4 changes: 3 additions & 1 deletion core/solver/gcr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ typename Gcr<ValueType>::parameters_type Gcr<ValueType>::build_from_config(
{
auto factory = solver::Gcr<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
if (auto& obj = config.get("krylov_dim")) {
factory.with_krylov_dim(gko::config::get_value<size_type>(obj));
}
return factory;
}

Expand Down
8 changes: 6 additions & 2 deletions core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ typename Gmres<ValueType>::parameters_type Gmres<ValueType>::build_from_config(
{
auto factory = solver::Gmres<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
SET_VALUE(factory, bool, flexible, config);
if (auto& obj = config.get("krylov_dim")) {
factory.with_krylov_dim(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("flexible")) {
factory.with_flexible(gko::config::get_value<bool>(obj));
}
return factory;
}

Expand Down
17 changes: 13 additions & 4 deletions core/solver/idr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,19 @@ typename Idr<ValueType>::parameters_type Idr<ValueType>::build_from_config(
{
auto factory = solver::Idr<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, subspace_dim, config);
SET_VALUE(factory, remove_complex<ValueType>, kappa, config);
SET_VALUE(factory, bool, deterministic, config);
SET_VALUE(factory, bool, complex_subspace, config);
if (auto& obj = config.get("subspace_dim")) {
factory.with_subspace_dim(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("kappa")) {
factory.with_kappa(
gko::config::get_value<remove_complex<ValueType>>(obj));
}
if (auto& obj = config.get("deterministic")) {
factory.with_deterministic(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("complex_subspace")) {
factory.with_complex_subspace(gko::config::get_value<bool>(obj));
}
return factory;
}

Expand Down
29 changes: 20 additions & 9 deletions core/solver/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,26 @@ typename Ir<ValueType>::parameters_type Ir<ValueType>::build_from_config(
config::type_descriptor td_for_child)
{
auto factory = solver::Ir<ValueType>::build();
SET_FACTORY_VECTOR(factory, const stop::CriterionFactory, criteria, config,
context, td_for_child);
SET_FACTORY(factory, const LinOpFactory, solver, config, context,
td_for_child);
SET_POINTER(factory, const LinOp, generated_solver, config, context,
td_for_child);
SET_VALUE(factory, ValueType, relaxation_factor, config);
SET_VALUE(factory, solver::initial_guess_mode, default_initial_guess,
config);
if (auto& obj = config.get("criteria")) {
factory.with_criteria(
gko::config::get_factory_vector<const stop::CriterionFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("solver")) {
factory.with_solver(gko::config::get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("generated_solver")) {
factory.with_generated_solver(
gko::config::get_pointer<const LinOp>(obj, context, td_for_child));
}
if (auto& obj = config.get("relaxation_factor")) {
factory.with_relaxation_factor(gko::config::get_value<ValueType>(obj));
}
if (auto& obj = config.get("default_initial_guess")) {
factory.with_default_initial_guess(
gko::config::get_value<solver::initial_guess_mode>(obj));
}
return factory;
}

Expand Down
12 changes: 8 additions & 4 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ LowerTrs<ValueType, IndexType>::build_from_config(
{
auto factory = LowerTrs<ValueType, IndexType>::build();
// duplicate?
SET_VALUE(factory, size_type, num_rhs, config);
SET_VALUE(factory, bool, unit_diagonal, config);
if (config.contains("algorithm")) {
if (auto& obj = config.get("num_rhs")) {
factory.with_num_rhs(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("unit_diagonal")) {
factory.with_unit_diagonal(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("algorithm")) {
using gko::solver::trisolve_algorithm;
auto str = config.at("algorithm").get_data<std::string>();
auto str = obj.get_string();
if (str == "sparselib") {
factory.with_algorithm(trisolve_algorithm::sparselib);
} else if (str == "syncfree") {
Expand Down
12 changes: 8 additions & 4 deletions core/solver/upper_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ UpperTrs<ValueType, IndexType>::build_from_config(
{
auto factory = UpperTrs<ValueType, IndexType>::build();
// duplicate?
SET_VALUE(factory, size_type, num_rhs, config);
SET_VALUE(factory, bool, unit_diagonal, config);
if (config.contains("algorithm")) {
if (auto& obj = config.get("num_rhs")) {
factory.with_num_rhs(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("unit_diagonal")) {
factory.with_unit_diagonal(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("algorithm")) {
using gko::solver::trisolve_algorithm;
auto str = config.at("algorithm").get_data<std::string>();
auto str = obj.get_string();
if (str == "sparselib") {
factory.with_algorithm(trisolve_algorithm::sparselib);
} else if (str == "syncfree") {
Expand Down
2 changes: 1 addition & 1 deletion core/test/config/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ginkgo_create_test(config)
ginkgo_create_test(property_tree)
ginkgo_create_test(registry)
ginkgo_create_test(solver)
ginkgo_create_test(solver)
Loading

0 comments on commit 2d5fed4

Please sign in to comment.