Skip to content

Commit

Permalink
fix compile error
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed May 9, 2024
1 parent 4e81126 commit 260a9c7
Show file tree
Hide file tree
Showing 29 changed files with 112 additions and 112 deletions.
7 changes: 4 additions & 3 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ get_value(const pnode& config)
* This is specialization for integral type
*/
template <typename IndexType>
inline
typename std::enable_if<std::is_integral<IndexType>::value, IndexType>::type
get_value(const pnode& config)
inline typename std::enable_if<std::is_integral<IndexType>::value &&
!std::is_same<IndexType, bool>::value,
IndexType>::type
get_value(const pnode& config)
{
auto val = config.get_integer();
GKO_THROW_IF_INVALID(
Expand Down
61 changes: 29 additions & 32 deletions core/config/solver_config.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017-2023 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -19,7 +19,7 @@
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/solver_config.hpp"

Expand All @@ -28,18 +28,18 @@ namespace gko {
namespace config {

// for valuetype only
#define PARSE(_type) \
template <> \
deferred_factory_parameter<gko::LinOpFactory> \
parse<LinOpFactoryType::_type>(const pnode& config, \
const registry& context, \
gko::config::type_descriptor td) \
{ \
auto updated = update_type(config, td); \
return dispatch<gko::LinOpFactory, gko::solver::_type>( \
config, context, updated, \
make_type_selector(updated.get_value_typestr(), \
value_type_list())); \
#define PARSE(_type) \
template <> \
deferred_factory_parameter<gko::LinOpFactory> \
parse<LinOpFactoryType::_type>(const pnode& config, \
const registry& context, \
const type_descriptor& td) \
{ \
auto updated = update_type(config, td); \
return dispatch<gko::LinOpFactory, gko::solver::_type>( \
config, context, updated, \
make_type_selector(updated.get_value_typestr(), \
value_type_list())); \
}

PARSE(Cg)
Expand All @@ -55,40 +55,37 @@ PARSE(CbGmres)


template <>
deferred_factory_parameter<gko::LinOpFactory>
parse<LinOpFactoryType::Direct>(const pnode& config,
const registry& context,
gko::config::type_descriptor td)
deferred_factory_parameter<gko::LinOpFactory> parse<LinOpFactoryType::Direct>(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, gko::experimental::solver::Direct>(
config, context, updated, make_type_selector(updated.get_value_typestr(), value_type_list()), make_type_selector(updated.get_index_typestr(), index_type_list())
);
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()),
make_type_selector(updated.get_index_typestr(), index_type_list()));
}


template <>
deferred_factory_parameter<gko::LinOpFactory>
parse<LinOpFactoryType::LowerTrs>(const pnode& config,
const registry& context,
gko::config::type_descriptor td)
deferred_factory_parameter<gko::LinOpFactory> parse<LinOpFactoryType::LowerTrs>(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, gko::solver::LowerTrs>(
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()), make_type_selector(updated.get_index_typestr(), index_type_list()));
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()),
make_type_selector(updated.get_index_typestr(), index_type_list()));
}

template <>
deferred_factory_parameter<gko::LinOpFactory>
parse<LinOpFactoryType::UpperTrs>(const pnode& config,
const registry& context,
gko::config::type_descriptor td)
deferred_factory_parameter<gko::LinOpFactory> parse<LinOpFactoryType::UpperTrs>(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, gko::solver::UpperTrs>(
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()), make_type_selector(updated.get_index_typestr(), index_type_list()));
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()),
make_type_selector(updated.get_index_typestr(), index_type_list()));
}


Expand Down
2 changes: 1 addition & 1 deletion core/config/solver_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <ginkgo/core/config/registry.hpp>


#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"

namespace gko {
Expand Down
2 changes: 1 addition & 1 deletion core/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ GKO_REGISTER_OPERATION(step_2, bicg::step_2);
template <typename ValueType>
typename Bicg<ValueType>::parameters_type Bicg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Bicg<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
7 changes: 3 additions & 4 deletions core/solver/bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ GKO_REGISTER_OPERATION(finalize, bicgstab::finalize);


template <typename ValueType>
typename Bicgstab<ValueType>::parameters_type
Bicgstab<ValueType>::parse(const config::pnode& config,
const config::registry& context,
config::type_descriptor td_for_child)
typename Bicgstab<ValueType>::parameters_type Bicgstab<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Bicgstab<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
7 changes: 3 additions & 4 deletions core/solver/cb_gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,9 @@ struct helper<std::complex<T>> {


template <typename ValueType>
typename CbGmres<ValueType>::parameters_type
CbGmres<ValueType>::parse(const config::pnode& config,
const config::registry& context,
config::type_descriptor td_for_child)
typename CbGmres<ValueType>::parameters_type CbGmres<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::CbGmres<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
2 changes: 1 addition & 1 deletion core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ GKO_REGISTER_OPERATION(step_2, cg::step_2);
template <typename ValueType>
typename Cg<ValueType>::parameters_type Cg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Cg<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
2 changes: 1 addition & 1 deletion core/solver/cgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ GKO_REGISTER_OPERATION(step_3, cgs::step_3);
template <typename ValueType>
typename Cgs<ValueType>::parameters_type Cgs<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Cgs<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
8 changes: 4 additions & 4 deletions core/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <ginkgo/core/solver/solver_base.hpp>


#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"


namespace gko {
Expand All @@ -23,9 +23,9 @@ namespace solver {

template <typename ValueType, typename IndexType>
typename Direct<ValueType, IndexType>::parameters_type
Direct<ValueType, IndexType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
Direct<ValueType, IndexType>::parse(const config::pnode& config,
const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = Direct<ValueType, IndexType>::build();
if (auto& obj = config.get("num_rhs")) {
Expand Down
2 changes: 1 addition & 1 deletion core/solver/fcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ GKO_REGISTER_OPERATION(step_2, fcg::step_2);
template <typename ValueType>
typename Fcg<ValueType>::parameters_type Fcg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Fcg<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
2 changes: 1 addition & 1 deletion core/solver/gcr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ GKO_REGISTER_OPERATION(step_1, gcr::step_1);
template <typename ValueType>
typename Gcr<ValueType>::parameters_type Gcr<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Gcr<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
2 changes: 1 addition & 1 deletion core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ GKO_REGISTER_OPERATION(multi_axpy, gmres::multi_axpy);
template <typename ValueType>
typename Gmres<ValueType>::parameters_type Gmres<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Gmres<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
2 changes: 1 addition & 1 deletion core/solver/idr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ GKO_REGISTER_OPERATION(compute_omega, idr::compute_omega);
template <typename ValueType>
typename Idr<ValueType>::parameters_type Idr<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Idr<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
Expand Down
6 changes: 3 additions & 3 deletions core/solver/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <ginkgo/core/solver/solver_base.hpp>


#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/ir_kernels.hpp"
#include "core/solver/solver_base.hpp"
Expand All @@ -33,7 +33,7 @@ GKO_REGISTER_OPERATION(initialize, ir::initialize);
template <typename ValueType>
typename Ir<ValueType>::parameters_type Ir<ValueType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = solver::Ir<ValueType>::build();
if (auto& obj = config.get("criteria")) {
Expand All @@ -47,7 +47,7 @@ typename Ir<ValueType>::parameters_type Ir<ValueType>::parse(
}
if (auto& obj = config.get("generated_solver")) {
factory.with_generated_solver(
gko::config::get_pointer<const LinOp>(obj, context, td_for_child));
gko::config::get_stored_obj<const LinOp>(obj, context));
}
if (auto& obj = config.get("relaxation_factor")) {
factory.with_relaxation_factor(gko::config::get_value<ValueType>(obj));
Expand Down
4 changes: 2 additions & 2 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"
#include "core/solver/lower_trs_kernels.hpp"


Expand All @@ -38,7 +38,7 @@ template <typename ValueType, typename IndexType>
typename LowerTrs<ValueType, IndexType>::parameters_type
LowerTrs<ValueType, IndexType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = LowerTrs<ValueType, IndexType>::build();
// duplicate?
Expand Down
4 changes: 2 additions & 2 deletions core/solver/upper_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"
#include "core/solver/upper_trs_kernels.hpp"


Expand All @@ -38,7 +38,7 @@ template <typename ValueType, typename IndexType>
typename UpperTrs<ValueType, IndexType>::parameters_type
UpperTrs<ValueType, IndexType>::parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
const config::type_descriptor& td_for_child)
{
auto factory = UpperTrs<ValueType, IndexType>::build();
// duplicate?
Expand Down
29 changes: 18 additions & 11 deletions core/test/config/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
#include <ginkgo/core/stop/iteration.hpp>


#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"
#include "core/config/registry_accessor.hpp"
#include "core/test/utils.hpp"


Expand Down Expand Up @@ -56,14 +57,16 @@ struct SolverConfigTest {
{
config_map["generated_preconditioner"] = pnode{"linop"};
param.with_generated_preconditioner(
reg.search_data<gko::LinOp>("linop"));
detail::registry_accessor::get_data<gko::LinOp>(reg, "linop"));
if (from_reg) {
config_map["criteria"] = pnode{"criterion_factory"};
param.with_criteria(reg.search_data<gko::stop::CriterionFactory>(
"criterion_factory"));
param.with_criteria(
detail::registry_accessor::get_data<
gko::stop::CriterionFactory>(reg, "criterion_factory"));
config_map["preconditioner"] = pnode{"linop_factory"};
param.with_preconditioner(
reg.search_data<gko::LinOpFactory>("linop_factory"));
detail::registry_accessor::get_data<gko::LinOpFactory>(
reg, "linop_factory"));
} else {
config_map["criteria"] = pnode{
std::map<std::string, pnode>{{"Type", pnode{"Iteration"}}}};
Expand Down Expand Up @@ -154,18 +157,21 @@ struct Ir : SolverConfigTest<gko::solver::Ir<float>, gko::solver::Ir<double>> {
std::shared_ptr<const gko::Executor> exec)
{
config_map["generated_solver"] = pnode{"linop"};
param.with_generated_solver(reg.search_data<gko::LinOp>("linop"));
param.with_generated_solver(
detail::registry_accessor::get_data<gko::LinOp>(reg, "linop"));
config_map["relaxation_factor"] = pnode{1.2};
param.with_relaxation_factor(decltype(param.relaxation_factor){1.2});
config_map["default_initial_guess"] = pnode{"zero"};
param.with_default_initial_guess(gko::solver::initial_guess_mode::zero);
if (from_reg) {
config_map["criteria"] = pnode{"criterion_factory"};
param.with_criteria(reg.search_data<gko::stop::CriterionFactory>(
"criterion_factory"));
param.with_criteria(
detail::registry_accessor::get_data<
gko::stop::CriterionFactory>(reg, "criterion_factory"));
config_map["solver"] = pnode{"linop_factory"};
param.with_solver(
reg.search_data<gko::LinOpFactory>("linop_factory"));
detail::registry_accessor::get_data<gko::LinOpFactory>(
reg, "linop_factory"));
} else {
config_map["criteria"] = pnode{
std::map<std::string, pnode>{{"Type", pnode{"Iteration"}}}};
Expand Down Expand Up @@ -358,7 +364,8 @@ struct Direct
if (from_reg) {
config_map["factorization"] = pnode{"linop_factory"};
param.with_factorization(
reg.search_data<gko::LinOpFactory>("linop_factory"));
detail::registry_accessor::get_data<gko::LinOpFactory>(
reg, "linop_factory"));
} else {
config_map["factorization"] =
pnode{{{"Type", pnode{"Cg"}}, {"ValueType", pnode{"double"}}}};
Expand Down Expand Up @@ -445,7 +452,7 @@ class Solver : public ::testing::Test {
solver_factory(DummySolver::build().on(exec)),
stop_factory(DummyStop::build().on(exec)),
td("double", "int"),
reg(generate_config_map())
reg()
{
reg.emplace("linop", mtx);
reg.emplace("linop_factory", solver_factory);
Expand Down
3 changes: 0 additions & 3 deletions include/ginkgo/core/config/registry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,6 @@ class registry final {
* {"cg", cg_shared_ptr}
* }}
* ```
* @param build_map the build map to dispatch the class base. Ginkgo
* provides `generate_config_map()` in config.hpp to provide the ginkgo
* build map. Users can extend this map to fit their own LinOpFactory.
*/
registry(
const std::unordered_map<std::string, detail::allowed_ptr>& stored_map,
Expand Down
6 changes: 3 additions & 3 deletions include/ginkgo/core/solver/bicg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ class Bicg
GKO_ENABLE_LIN_OP_FACTORY(Bicg, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);

static parameters_type parse(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child);
static parameters_type parse(const config::pnode& config,
const config::registry& context,
const config::type_descriptor& td_for_child);

protected:
void apply_impl(const LinOp* b, LinOp* x) const override;
Expand Down
Loading

0 comments on commit 260a9c7

Please sign in to comment.