diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index ac15a869e79..341983e72fd 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -19,7 +19,12 @@ target_sources(ginkgo base/segmented_array.cpp base/timer.cpp base/version.cpp + config/config.cpp + config/config_helper.cpp config/property_tree.cpp + config/registry.cpp + config/stop_config.cpp + config/type_descriptor.cpp distributed/index_map.cpp distributed/partition.cpp factorization/cholesky.cpp diff --git a/core/config/config.cpp b/core/config/config.cpp new file mode 100644 index 00000000000..8c97c0038ed --- /dev/null +++ b/core/config/config.cpp @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/registry_accessor.hpp" + + +namespace gko { +namespace config { + + +deferred_factory_parameter parse(const pnode& config, + const registry& context, + const type_descriptor& td) +{ + if (auto& obj = config.get("type")) { + auto func = detail::registry_accessor::get_build_map(context).at( + obj.get_string()); + return func(config, context, td); + } + GKO_INVALID_STATE("Should contain type property"); +} + + +} // namespace config +} // namespace gko diff --git a/core/config/config_helper.cpp b/core/config/config_helper.cpp new file mode 100644 index 00000000000..5f26d927cad --- /dev/null +++ b/core/config/config_helper.cpp @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/registry_accessor.hpp" +#include "core/config/stop_config.hpp" + +namespace gko { +namespace config { + + +template <> +deferred_factory_parameter +parse_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td) +{ + if (config.get_tag() == pnode::tag_t::string) { + return detail::registry_accessor::get_data( + context, config.get_string()); + } else if (config.get_tag() == pnode::tag_t::map) { + return parse(config, context, td); + } else { + GKO_INVALID_STATE("The data of config is not valid."); + } +} + + +template <> +deferred_factory_parameter +parse_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td) +{ + if (config.get_tag() == pnode::tag_t::string) { + return detail::registry_accessor::get_data( + context, config.get_string()); + } else if (config.get_tag() == pnode::tag_t::map) { + static std::map( + const pnode&, const registry&, type_descriptor)>> + criterion_map{ + {{"Time", configure_time}, + {"Iteration", configure_iter}, + {"ResidualNorm", configure_residual}, + {"ImplicitResidualNorm", configure_implicit_residual}}}; + return criterion_map.at(config.get("type").get_string())(config, + context, td); + } else { + GKO_INVALID_STATE("The data of config is not valid."); + } +} + +} // namespace config +} // namespace gko diff --git a/core/config/config_helper.hpp b/core/config/config_helper.hpp new file mode 100644 index 00000000000..798d3623856 --- /dev/null +++ b/core/config/config_helper.hpp @@ -0,0 +1,180 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_CONFIG_HELPER_HPP_ +#define GKO_CORE_CONFIG_CONFIG_HELPER_HPP_ + + +#include +#include + + +#include +#include +#include +#include +#include +#include +#include + + +#include "core/config/registry_accessor.hpp" + + +namespace gko { +namespace config { + + +/** + * LinOpFactoryType enum is to avoid forward declaration, linopfactory header, + * two template versions of parse + */ +enum class LinOpFactoryType : int { Cg = 0 }; + + +/** + * It is only an intermediate step after dispatching the class base type. Each + * implementation needs to deal with the template selection. + */ +template +deferred_factory_parameter parse( + const pnode& config, const registry& context, + const type_descriptor& td = make_type_descriptor<>()); + + +/** + * get_stored_obj searches the object pointer stored in the registry by string + */ +template +inline std::shared_ptr get_stored_obj(const pnode& config, + const registry& context) +{ + std::shared_ptr ptr; + using T_non_const = std::remove_const_t; + ptr = detail::registry_accessor::get_data(context, + config.get_string()); + GKO_THROW_IF_INVALID(ptr.get() != nullptr, "Do not get the stored data"); + return ptr; +} + + +/** + * Build the factory from config (map) or search the pointers in + * the registry by string. + */ +template +deferred_factory_parameter parse_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td); + +/** + * specialize for const LinOpFactory + */ +template <> +deferred_factory_parameter +parse_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td); + +/** + * specialize for const stop::CriterionFactory + */ +template <> +deferred_factory_parameter +parse_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td); + +/** + * give a vector of factory by calling parse_or_get_factory. + */ +template +inline std::vector> parse_or_get_factory_vector( + const pnode& config, const registry& context, const type_descriptor& td) +{ + std::vector> res; + if (config.get_tag() == pnode::tag_t::array) { + for (const auto& it : config.get_array()) { + res.push_back(parse_or_get_factory(it, context, td)); + } + } else { + // only one config can be passed without array + res.push_back(parse_or_get_factory(config, context, td)); + } + + return res; +} + + +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for integral type + */ +template +inline std::enable_if_t::value, IndexType> +get_value(const pnode& config) +{ + auto val = config.get_integer(); + GKO_THROW_IF_INVALID( + val <= std::numeric_limits::max() && + val >= std::numeric_limits::min(), + "the config value is out of the range of the require type."); + return static_cast(val); +} + + +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for floating point type + */ +template +inline std::enable_if_t::value, ValueType> +get_value(const pnode& config) +{ + auto val = config.get_real(); + // the max, min of floating point only consider positive value. + GKO_THROW_IF_INVALID( + val <= std::numeric_limits::max() && + val >= -std::numeric_limits::max(), + "the config value is out of the range of the require type."); + return static_cast(val); +} + +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for complex type + */ +template +inline std::enable_if_t::value, ValueType> +get_value(const pnode& config) +{ + using real_type = gko::remove_complex; + if (config.get_tag() == pnode::tag_t::real) { + return static_cast(get_value(config)); + } else if (config.get_tag() == pnode::tag_t::array) { + real_type real(0); + real_type imag(0); + if (config.get_array().size() >= 1) { + real = get_value(config.get(0)); + } + if (config.get_array().size() >= 2) { + imag = get_value(config.get(1)); + } + GKO_THROW_IF_INVALID( + config.get_array().size() <= 2, + "complex value array expression only accept up to two elements"); + return ValueType{real, imag}; + } + GKO_INVALID_STATE("Can not get complex value"); +} + + +} // namespace config +} // namespace gko + + +#endif // GKO_CORE_CONFIG_CONFIG_HELPER_HPP_ diff --git a/core/config/dispatch.hpp b/core/config/dispatch.hpp new file mode 100644 index 00000000000..c765150f72a --- /dev/null +++ b/core/config/dispatch.hpp @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_DISPATCH_HPP_ +#define GKO_CORE_CONFIG_DISPATCH_HPP_ + + +#include +#include + + +#include +#include +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/type_descriptor_helper.hpp" + + +namespace gko { +namespace config { + + +/** + * type_selector connect the runtime string and the allowed type list together + */ +template +struct type_selector { + explicit type_selector(const std::string& input) : runtime(input) {} + + std::string runtime; +}; + + +/** + * It is the helper function to create type_selector with the type_list as the + * argument. + */ +template +type_selector make_type_selector(const std::string& runtime_type, + syn::type_list) +{ + return type_selector{runtime_type}; +} + + +/** + * This function is to dispatch the type from runtime parameter. + * The concrete class need to have static member function + * parse(pnode, registry, type_descriptor) + */ +template class Base, + typename... Types> +deferred_factory_parameter dispatch(const pnode& config, + const registry& context, + const type_descriptor& td) +{ + return Base::parse(config, context, td); +} + +// When the dispatch does not find match from the given list. +template class Base, + typename... Types, typename... Rest> +deferred_factory_parameter dispatch(const pnode& config, + const registry& context, + const type_descriptor& td, + type_selector<> selector, + Rest... rest) +{ + GKO_INVALID_STATE("The provided runtime type >" + selector.runtime + + "< doesn't match any of the allowed compile time types."); +} + +/** + * This function is to dispatch the type from runtime parameter. + * The concrete class need to have static member function + * `parse(pnode, registry, type_descriptor)` + * + * @param config the configuration + * @param context the registry context + * @param td the default type descriptor + * @param selector the current dispatching type_selector + * @param rest... the type_selector list for the rest + */ +template class Base, + typename... Types, typename S, typename... AllowedTypes, + typename... Rest> +deferred_factory_parameter dispatch( + const pnode& config, const registry& context, const type_descriptor& td, + type_selector selector, Rest... rest) +{ + if (selector.runtime == type_string::str()) { + return dispatch(config, context, td, + rest...); + } else { + return dispatch( + config, context, td, + type_selector(selector.runtime), rest...); + } +} + +using value_type_list = + syn::type_list, std::complex>; + + +} // namespace config +} // namespace gko + +#endif // GKO_CORE_CONFIG_DISPATCH_HPP_ diff --git a/core/config/registry.cpp b/core/config/registry.cpp new file mode 100644 index 00000000000..1113adb93f4 --- /dev/null +++ b/core/config/registry.cpp @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include +#include + + +#include "core/config/config_helper.hpp" + + +namespace gko { +namespace config { + + +configuration_map generate_config_map() +{ + return {{"solver::Cg", parse}}; +} + + +registry::registry(const configuration_map& additional_map) + : registry({}, additional_map) +{} + + +registry::registry( + const std::unordered_map& stored_map, + const configuration_map& additional_map) + : stored_map_(stored_map), build_map_(generate_config_map()) +{ + // merge additional_map into build_map_ + for (auto& item : additional_map) { + auto res = build_map_.emplace(item.first, item.second); + GKO_THROW_IF_INVALID(res.second, + "failed when adding the key " + item.first); + } +} + + +} // namespace config +} // namespace gko diff --git a/core/config/registry_accessor.hpp b/core/config/registry_accessor.hpp new file mode 100644 index 00000000000..002e6245811 --- /dev/null +++ b/core/config/registry_accessor.hpp @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_REGISTRY_ACCESSOR_HPP_ +#define GKO_CORE_CONFIG_REGISTRY_ACCESSOR_HPP_ + + +#include + + +#include + + +namespace gko { +namespace config { +namespace detail { + + +class registry_accessor { +public: + template + static inline std::shared_ptr get_data(const registry& reg, + std::string key) + { + return reg.get_data(key); + } + + static inline const configuration_map& get_build_map(const registry& reg) + { + return reg.get_build_map(); + } +}; + + +} // namespace detail +} // namespace config +} // namespace gko + + +#endif // GKO_CORE_CONFIG_REGISTRY_ACCESSOR_HPP_ diff --git a/core/config/stop_config.cpp b/core/config/stop_config.cpp new file mode 100644 index 00000000000..63148cbfcd9 --- /dev/null +++ b/core/config/stop_config.cpp @@ -0,0 +1,127 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/dispatch.hpp" +#include "core/config/registry_accessor.hpp" +#include "core/config/stop_config.hpp" +#include "core/config/type_descriptor_helper.hpp" + + +namespace gko { +namespace config { + + +deferred_factory_parameter configure_time( + const pnode& config, const registry& context, const type_descriptor& td) +{ + auto factory = stop::Time::build(); + if (auto& obj = config.get("time_limit")) { + factory.with_time_limit(gko::config::get_value(obj)); + } + return factory; +} + + +deferred_factory_parameter configure_iter( + const pnode& config, const registry& context, const type_descriptor& td) +{ + auto factory = stop::Iteration::build(); + if (auto& obj = config.get("max_iters")) { + factory.with_max_iters(gko::config::get_value(obj)); + } + return factory; +} + + +inline stop::mode get_mode(const std::string& str) +{ + if (str == "absolute") { + return stop::mode::absolute; + } else if (str == "initial_resnorm") { + return stop::mode::initial_resnorm; + } else if (str == "rhs_norm") { + return stop::mode::rhs_norm; + } + GKO_INVALID_STATE("Not valid " + str); +} + + +template +class ResidualNormConfigurer { +public: + static deferred_factory_parameter< + typename stop::ResidualNorm::Factory> + parse(const gko::config::pnode& config, + const gko::config::registry& context, + const gko::config::type_descriptor& td_for_child) + { + auto params = stop::ResidualNorm::build(); + if (auto& obj = config.get("reduction_factor")) { + params.with_reduction_factor( + gko::config::get_value>(obj)); + } + if (auto& obj = config.get("baseline")) { + params.with_baseline(get_mode(obj.get_string())); + } + return params; + } +}; + + +deferred_factory_parameter configure_residual( + const pnode& config, const registry& context, const type_descriptor& td) +{ + auto updated = update_type(config, td); + return dispatch( + config, context, updated, + make_type_selector(updated.get_value_typestr(), value_type_list())); +} + + +template +class ImplicitResidualNormConfigurer { +public: + static deferred_factory_parameter< + typename stop::ImplicitResidualNorm::Factory> + parse(const gko::config::pnode& config, + const gko::config::registry& context, + const gko::config::type_descriptor& td_for_child) + { + auto params = stop::ImplicitResidualNorm::build(); + if (auto& obj = config.get("reduction_factor")) { + params.with_reduction_factor( + gko::config::get_value>(obj)); + } + if (auto& obj = config.get("baseline")) { + params.with_baseline(get_mode(obj.get_string())); + } + return params; + } +}; + + +deferred_factory_parameter configure_implicit_residual( + const pnode& config, const registry& context, const type_descriptor& td) +{ + auto updated = update_type(config, td); + return dispatch( + config, context, updated, + make_type_selector(updated.get_value_typestr(), value_type_list())); +} + + +} // namespace config +} // namespace gko diff --git a/core/config/stop_config.hpp b/core/config/stop_config.hpp new file mode 100644 index 00000000000..edb99c6a634 --- /dev/null +++ b/core/config/stop_config.hpp @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_STOP_CONFIG_HPP_ +#define GKO_CORE_CONFIG_STOP_CONFIG_HPP_ + + +#include +#include +#include +#include + + +namespace gko { +namespace config { + + +deferred_factory_parameter configure_time( + const pnode& config, const registry& context, const type_descriptor& td); + +deferred_factory_parameter configure_iter( + const pnode& config, const registry& context, const type_descriptor& td); + +deferred_factory_parameter configure_residual( + const pnode& config, const registry& context, const type_descriptor& td); + +deferred_factory_parameter configure_implicit_residual( + const pnode& config, const registry& context, const type_descriptor& td); + +} // namespace config +} // namespace gko + + +#endif // GKO_CORE_CONFIG_STOP_CONFIG_HPP_ diff --git a/core/config/type_descriptor.cpp b/core/config/type_descriptor.cpp new file mode 100644 index 00000000000..c2885407cad --- /dev/null +++ b/core/config/type_descriptor.cpp @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +#include "core/config/type_descriptor_helper.hpp" + + +namespace gko { +namespace config { + + +type_descriptor update_type(const pnode& config, const type_descriptor& td) +{ + auto value_typestr = td.get_value_typestr(); + auto index_typestr = td.get_index_typestr(); + + if (auto& obj = config.get("value_type")) { + value_typestr = obj.get_string(); + } + if (auto& obj = config.get("index_type")) { + GKO_INVALID_STATE( + "Setting index_type in the config is not allowed. Please set the " + "proper index_type through type_descriptor of parse"); + } + return type_descriptor{value_typestr, index_typestr}; +} + + +template +type_descriptor make_type_descriptor() +{ + return type_descriptor{type_string::str(), + type_string::str()}; +} + +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor, void>(); +template type_descriptor make_type_descriptor, void>(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor, int32>(); +template type_descriptor make_type_descriptor, int32>(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor(); +template type_descriptor make_type_descriptor, int64>(); +template type_descriptor make_type_descriptor, int64>(); + + +type_descriptor::type_descriptor(std::string value_typestr, + std::string index_typestr) + : value_typestr_(value_typestr), index_typestr_(index_typestr) +{} + +const std::string& type_descriptor::get_value_typestr() const +{ + return value_typestr_; +} + +const std::string& type_descriptor::get_index_typestr() const +{ + return index_typestr_; +} + + +} // namespace config +} // namespace gko diff --git a/core/config/type_descriptor_helper.hpp b/core/config/type_descriptor_helper.hpp new file mode 100644 index 00000000000..3917e317773 --- /dev/null +++ b/core/config/type_descriptor_helper.hpp @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_TYPE_DESCRIPTOR_HELPER_HPP_ +#define GKO_CORE_CONFIG_TYPE_DESCRIPTOR_HELPER_HPP_ + + +#include + + +#include +#include +#include + + +namespace gko { +namespace config { + + +/** + This function updates the default type setting from current config. Any type + that is not specified in the config will fall back to the type stored in the + current type_descriptor. + */ +type_descriptor update_type(const pnode& config, const type_descriptor& td); + + +// type_string providing the mapping from type to string. +template +struct type_string {}; + +#define TYPE_STRING_OVERLOAD(_type, _str) \ + template <> \ + struct type_string<_type> { \ + static std::string str() { return _str; } \ + } + +TYPE_STRING_OVERLOAD(void, "void"); +TYPE_STRING_OVERLOAD(double, "float64"); +TYPE_STRING_OVERLOAD(float, "float32"); +TYPE_STRING_OVERLOAD(std::complex, "complex"); +TYPE_STRING_OVERLOAD(std::complex, "complex"); +TYPE_STRING_OVERLOAD(int32, "int32"); +TYPE_STRING_OVERLOAD(int64, "int64"); + +#undef TYPE_STRING_OVERLOAD + + +} // namespace config +} // namespace gko + + +#endif // GKO_CORE_CONFIG_TYPE_DESCRIPTOR_HELPER_HPP_ diff --git a/core/solver/cg.cpp b/core/solver/cg.cpp index e445cfcafaf..71e5fcfbb3b 100644 --- a/core/solver/cg.cpp +++ b/core/solver/cg.cpp @@ -12,14 +12,36 @@ #include #include #include +#include +#include +#include "core/config/config_helper.hpp" +#include "core/config/dispatch.hpp" +#include "core/config/type_descriptor_helper.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/cg_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" namespace gko { +namespace config { + + +template <> +deferred_factory_parameter parse( + const pnode& config, const registry& context, const type_descriptor& td) +{ + auto updated = update_type(config, td); + return dispatch( + config, context, updated, + make_type_selector(updated.get_value_typestr(), value_type_list())); +} + + +} // namespace config + + namespace solver { namespace cg { namespace { @@ -34,6 +56,31 @@ GKO_REGISTER_OPERATION(step_2, cg::step_2); } // namespace cg +template +typename Cg::parameters_type Cg::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Cg::build(); + // The following will be moved to the common solver function in another pr + if (auto& obj = config.get("generated_preconditioner")) { + params.with_generated_preconditioner( + gko::config::get_stored_obj(obj, context)); + } + if (auto& obj = config.get("criteria")) { + params.with_criteria( + gko::config::parse_or_get_factory_vector< + const stop::CriterionFactory>(obj, context, td_for_child)); + } + if (auto& obj = config.get("preconditioner")) { + params.with_preconditioner( + gko::config::parse_or_get_factory( + obj, context, td_for_child)); + } + return params; +} + + template std::unique_ptr Cg::transpose() const { diff --git a/core/test/config/CMakeLists.txt b/core/test/config/CMakeLists.txt index e842152634c..a8783cd4a20 100644 --- a/core/test/config/CMakeLists.txt +++ b/core/test/config/CMakeLists.txt @@ -1 +1,3 @@ +ginkgo_create_test(config) ginkgo_create_test(property_tree) +ginkgo_create_test(registry) diff --git a/core/test/config/config.cpp b/core/test/config/config.cpp new file mode 100644 index 00000000000..163f6936de2 --- /dev/null +++ b/core/test/config/config.cpp @@ -0,0 +1,191 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +#include +#include +#include +#include +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/test/utils.hpp" + + +namespace { + + +using namespace gko::config; + + +class Config : public ::testing::Test { +protected: + using value_type = double; + using Mtx = gko::matrix::Dense; + Config() + : exec(gko::ReferenceExecutor::create()), + mtx(gko::initialize( + {{2, -1.0, 0.0}, {-1.0, 2, -1.0}, {0.0, -1.0, 2}}, exec)), + stop_config({{"type", pnode{"Iteration"}}, {"max_iters", pnode{1}}}) + {} + + std::shared_ptr exec; + std::shared_ptr mtx; + pnode stop_config; +}; + + +TEST_F(Config, GenerateObjectWithoutDefault) +{ + auto reg = registry(); + + pnode p{ + {{"value_type", pnode{"float64"}}, {"criteria", this->stop_config}}}; + auto obj = parse(p, reg).on(this->exec); + + ASSERT_NE(dynamic_cast::Factory*>(obj.get()), + nullptr); +} + + +TEST_F(Config, GenerateObjectWithData) +{ + auto reg = registry(); + reg.emplace("precond", this->mtx); + + pnode p{{{"generated_preconditioner", pnode{"precond"}}, + {"criteria", this->stop_config}}}; + auto obj = + parse(p, reg, type_descriptor{"float32", "void"}) + .on(this->exec); + + ASSERT_NE(dynamic_cast::Factory*>(obj.get()), + nullptr); + ASSERT_NE(dynamic_cast::Factory*>(obj.get()) + ->get_parameters() + .generated_preconditioner, + nullptr); +} + + +TEST_F(Config, GenerateObjectWithPreconditioner) +{ + auto reg = registry(); + auto precond_node = + pnode{{{"type", pnode{"solver::Cg"}}, {"criteria", this->stop_config}}}; + pnode p{{{"value_type", pnode{"float64"}}, + {"criteria", this->stop_config}, + {"preconditioner", precond_node}}}; + + auto obj = parse(p, reg).on(this->exec); + + ASSERT_NE(dynamic_cast::Factory*>(obj.get()), + nullptr); + ASSERT_NE(dynamic_cast::Factory*>(obj.get()) + ->get_parameters() + .preconditioner, + nullptr); +} + + +TEST_F(Config, GenerateObjectWithCustomBuild) +{ + configuration_map config_map; + config_map["Custom"] = [](const pnode& config, const registry& context, + const type_descriptor& td_for_child) { + return gko::solver::Bicg::build().with_criteria( + gko::stop::Iteration::build().with_max_iters(2u)); + }; + auto reg = registry(config_map); + auto precond_node = + pnode{std::map{{"type", pnode{"Custom"}}}}; + pnode p{{{"value_type", pnode{"float64"}}, + {"criteria", this->stop_config}, + {"preconditioner", precond_node}}}; + + auto obj = + parse(p, reg, type_descriptor{"float64", "void"}) + .on(this->exec); + + ASSERT_NE(dynamic_cast::Factory*>(obj.get()), + nullptr); + ASSERT_NE(dynamic_cast::Factory*>( + dynamic_cast::Factory*>(obj.get()) + ->get_parameters() + .preconditioner.get()), + nullptr); +} + + +TEST(GetValue, IndexType) +{ + long long int value = 123; + pnode config{value}; + + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + testing::StaticAssertTypeEq(config)), int>(); + testing::StaticAssertTypeEq(config)), long>(); + testing::StaticAssertTypeEq(config)), + unsigned>(); + testing::StaticAssertTypeEq(config)), + long long int>(); +} + + +TEST(GetValue, RealType) +{ + double value = 1.0; + pnode config{value}; + + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + testing::StaticAssertTypeEq(config)), float>(); + testing::StaticAssertTypeEq(config)), double>(); +} + + +TEST(GetValue, ComplexType) +{ + double real = 1.0; + double imag = -1.0; + pnode config{real}; + pnode array_config{pnode::array_type{pnode{real}, pnode{imag}}}; + + // Only one value + ASSERT_EQ(get_value>(config), + std::complex(real)); + ASSERT_EQ(get_value>(config), + std::complex(real)); + testing::StaticAssertTypeEq>( + config)), + std::complex>(); + testing::StaticAssertTypeEq>( + config)), + std::complex>(); + // Two value [real, imag] + ASSERT_EQ(get_value>(array_config), + std::complex(real, imag)); + ASSERT_EQ(get_value>(array_config), + std::complex(real, imag)); + testing::StaticAssertTypeEq>( + array_config)), + std::complex>(); + testing::StaticAssertTypeEq>( + array_config)), + std::complex>(); +} + + +} // namespace diff --git a/core/test/config/registry.cpp b/core/test/config/registry.cpp new file mode 100644 index 00000000000..e6fc8eef671 --- /dev/null +++ b/core/test/config/registry.cpp @@ -0,0 +1,201 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +#include +#include +#include +#include +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/registry_accessor.hpp" +#include "core/test/utils.hpp" + + +using namespace gko::config; + + +class Registry : public ::testing::Test { +protected: + using Matrix = gko::matrix::Dense; + using Solver = gko::solver::Cg; + using Stop = gko::stop::Iteration; + + Registry() + : exec{gko::ReferenceExecutor::create()}, + matrix{Matrix::create(exec)}, + solver_factory{Solver::build().on(exec)}, + stop_factory{Stop::build().on(exec)}, + func{[](const pnode& config, const registry& context, + const type_descriptor& td_for_child) { + return gko::solver::Cg::build(); + }}, + reg{{{"func", func}}} + {} + + std::shared_ptr exec; + std::shared_ptr matrix; + std::shared_ptr solver_factory; + std::shared_ptr stop_factory; + std::function( + const pnode&, const registry&, type_descriptor)> + func; + registry reg; +}; + + +TEST_F(Registry, InsertData) +{ + { + SCOPED_TRACE("can put data"); + ASSERT_TRUE(reg.emplace("matrix", matrix)); + ASSERT_TRUE(reg.emplace("solver_factory", solver_factory)); + ASSERT_TRUE(reg.emplace("stop_factory", stop_factory)); + } + { + SCOPED_TRACE("do not insert the same key like normal map"); + ASSERT_FALSE(reg.emplace("matrix", matrix)); + ASSERT_FALSE(reg.emplace("solver_factory", solver_factory)); + ASSERT_FALSE(reg.emplace("stop_factory", stop_factory)); + } +} + + +TEST_F(Registry, SearchData) +{ + reg.emplace("matrix", matrix); + reg.emplace("solver_factory", solver_factory); + reg.emplace("stop_factory", stop_factory); + + auto found_matrix = + detail::registry_accessor::get_data(reg, "matrix"); + auto found_solver_factory = + detail::registry_accessor::get_data( + reg, "solver_factory"); + auto found_stop_factory = + detail::registry_accessor::get_data( + reg, "stop_factory"); + + // get correct ptrs + ASSERT_EQ(found_matrix, matrix); + ASSERT_EQ(found_solver_factory, solver_factory); + ASSERT_EQ(found_stop_factory, stop_factory); + // get correct types + testing::StaticAssertTypeEq>(); + testing::StaticAssertTypeEq>(); + testing::StaticAssertTypeEq>(); +} + + +TEST_F(Registry, SearchDataWithType) +{ + reg.emplace("matrix", matrix); + reg.emplace("solver_factory", solver_factory); + reg.emplace("stop_factory", stop_factory); + + auto found_matrix = + detail::registry_accessor::get_data(reg, "matrix"); + auto found_solver_factory = + detail::registry_accessor::get_data(reg, + "solver_factory"); + auto found_stop_factory = + detail::registry_accessor::get_data(reg, "stop_factory"); + + // get correct ptrs + ASSERT_EQ(found_matrix, matrix); + ASSERT_EQ(found_solver_factory, solver_factory); + ASSERT_EQ(found_stop_factory, stop_factory); + // get correct types + testing::StaticAssertTypeEq>(); + testing::StaticAssertTypeEq>(); + testing::StaticAssertTypeEq>(); +} + + +TEST_F(Registry, BuildFromConstructor) +{ + registry reg_obj{{{"matrix", matrix}, + {"solver_factory", solver_factory}, + {"stop_factory", stop_factory}}}; + + auto found_matrix = + detail::registry_accessor::get_data(reg_obj, "matrix"); + auto found_solver_factory = + detail::registry_accessor::get_data(reg_obj, + "solver_factory"); + auto found_stop_factory = + detail::registry_accessor::get_data(reg_obj, + "stop_factory"); + // get correct ptrs + ASSERT_EQ(found_matrix, matrix); + ASSERT_EQ(found_solver_factory, solver_factory); + ASSERT_EQ(found_stop_factory, stop_factory); + // get correct types + testing::StaticAssertTypeEq>(); + testing::StaticAssertTypeEq>(); + testing::StaticAssertTypeEq>(); +} + + +TEST_F(Registry, ThrowIfNotFound) +{ + ASSERT_THROW(detail::registry_accessor::get_data(reg, "N"), + std::out_of_range); + ASSERT_THROW( + detail::registry_accessor::get_data(reg, "N"), + std::out_of_range); + ASSERT_THROW( + detail::registry_accessor::get_data(reg, + "N"), + std::out_of_range); +} + + +TEST_F(Registry, ThrowWithWrongType) +{ + reg.emplace("matrix", matrix); + reg.emplace("solver_factory", solver_factory); + reg.emplace("stop_factory", stop_factory); + + ASSERT_THROW( + detail::registry_accessor::get_data>( + reg, "matrix"), + gko::NotSupported); + ASSERT_THROW( + detail::registry_accessor::get_data::Factory>( + reg, "solver_factory"), + gko::NotSupported); + ASSERT_THROW(detail::registry_accessor::get_data( + reg, "stop_factory"), + gko::NotSupported); +} + + +TEST_F(Registry, GetBuildMap) +{ + auto factory = + detail::registry_accessor::get_build_map(reg) + .at("func")(pnode{"unused"}, reg, type_descriptor{"void", "void"}) + .on(exec); + + ASSERT_NE(factory, nullptr); +} diff --git a/core/test/config/type_descriptor.cpp b/core/test/config/type_descriptor.cpp new file mode 100644 index 00000000000..a387ebe44b7 --- /dev/null +++ b/core/test/config/type_descriptor.cpp @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +#include "core/config/type_descriptor_helper.hpp" +#include "core/test/utils.hpp" + + +using namespace gko::config; + + +TEST(TypeDescriptor, TemplateCreate) +{ + { + SCOPED_TRACE("default template"); + auto td = make_type_descriptor<>(); + + ASSERT_EQ(td.get_value_typestr(), "float64"); + ASSERT_EQ(td.get_index_typestr(), "int32"); + } + { + SCOPED_TRACE("specify valuetype"); + auto td = make_type_descriptor(); + + ASSERT_EQ(td.get_value_typestr(), "float32"); + ASSERT_EQ(td.get_index_typestr(), "int32"); + } + { + SCOPED_TRACE("specify all template"); + auto td = make_type_descriptor, gko::int64>(); + + ASSERT_EQ(td.get_value_typestr(), "complex"); + ASSERT_EQ(td.get_index_typestr(), "int64"); + } + { + SCOPED_TRACE("specify void"); + auto td = make_type_descriptor(); + + ASSERT_EQ(td.get_value_typestr(), "void"); + ASSERT_EQ(td.get_index_typestr(), "void"); + } +} + + +TEST(TypeDescriptor, Constructor) +{ + { + SCOPED_TRACE("default constructor"); + type_descriptor td; + + ASSERT_EQ(td.get_value_typestr(), "float64"); + ASSERT_EQ(td.get_index_typestr(), "int32"); + } + { + SCOPED_TRACE("specify valuetype"); + type_descriptor td("float32"); + + ASSERT_EQ(td.get_value_typestr(), "float32"); + ASSERT_EQ(td.get_index_typestr(), "int32"); + } + { + SCOPED_TRACE("specify all parameters"); + type_descriptor td("void", "int64"); + + ASSERT_EQ(td.get_value_typestr(), "void"); + ASSERT_EQ(td.get_index_typestr(), "int64"); + } +} diff --git a/include/ginkgo/core/config/config.hpp b/include/ginkgo/core/config/config.hpp new file mode 100644 index 00000000000..5d77c6f71c7 --- /dev/null +++ b/include/ginkgo/core/config/config.hpp @@ -0,0 +1,174 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_PUBLIC_CORE_CONFIG_CONFIG_HPP_ +#define GKO_PUBLIC_CORE_CONFIG_CONFIG_HPP_ + + +#include +#include +#include + + +#include +#include +#include +#include + + +namespace gko { +namespace config { + +class registry; + + +class pnode; + + +/** + * parse is the main entry point to create an Ginkgo LinOpFactory based on + * some file configuration. It reads a configuration stored as a property tree + * and creates the desired type. + * + * General rules for configuration + * 1. The configuration can be used to define factory parameters and class + * template parameters. Any factory parameter that is not defined in the + * configuration will fallback to their default value. Any template parameter + * that is not defined will fallback to the type_descriptor argument + * 2. The new `"type"` key determines which Ginkgo object to create. The value + * for this key is the desired class name with namespaces (except for + * `gko::`, `experimental::`, `stop::`). Any template parameters a class + * might have are left out. Only classes with a factory are supported. For + * example, the configuration `"type": "solver::Cg"` specifies that a Cg + * solver will be created. Note: template parameters can either be given in + * the configuration as separate key-value pairs, or in the type_descriptor. + * 3. Factory and class template parameters can be defined with key-value pairs + * that are derived from the class they are referring to. When a factory has + * a parameter with the function `with_(value)`, then the configuration + * allows `"": value` to define this parameter. When a class has a + * template parameter `template class`, then + * the configuration allows `"": value` to the template parameter. The + * supported values of the template parameter depend on the context. For + * index and value types, these are listed under 4. + * 4. Values for template parameters are represented with strings. The following + * datatypes, with postfix to indicate their size, are supported: int32, + * int64, float32, float64, complex, complex. + * 5. All keys use snake_case including template parameters. Factory parameter + * keys are already defined with snake_case in their factories, while class + * template arguments need to be translated, i.e. `ValueType -> value_type`. + * 6. The allowed values for factory parameters depend on the type the parameter + * is stored as. There are three distinct options: + * - POD types (bool, integer, floating point, or enum): Except for enum, + * the value has to be the POD type. For enums, a string value is used to + * represent them. The string has to be one of the possible enum values. + * An example of this type of parameter is the `krylov_dim` parameter for + * the Gmres solver. + * - LinOp (smart) pointers: The value has to be a string. The string is used + * to look up a LinOp object in the registry. + * An example is the `generated_preconditioner` parameter for iterative + * solvers such as Cg. + * - LinOpFactory (smart) pointers: The value can either be a string, or a + * nested configuration. The string has the same behavior as for LinOp + * pointers, i.e. an LinOpFactory object from the registry is taken. The + * nested configuration has to adhere to the general configuration rules + * again. See the examples below for some use-cases. + * An example is the `preconditioner` parameter for iterative solvers + * such as Cg. + * - CriterionFactory (smart) pointers: The value can either be a string, or + * a nested configuration. It has the same behavior as for LinOpFactory. + * - A vector of the types above: The value has to be an array with the + * inner values specified as above. + * 7. Complex values are represented as an 2-element array [real, imag]. If the + * array contains only one value, it will be considered as a complex number + * with an imaginary part = 0. An empty array will be treated as zero. + * 8. Keys that expect array of objects also accept single object which is + * interpreted as a 1-element array. This means the following configurations + * are equivalent if the key expects an array value: `"": [{object}]` + * and `"": {object}`. + * + * All configurations need to specify the resulting type by the field: + * ``` + * "type": "some_supported_ginkgo_type" + * ``` + * The result will be a deferred_factory_parameter, which is an intermediate + * step before a LinOpFactory. Providing an Executor through the function + * `.on(exec)` will then create the factory with the parameters as defined in + * the configuration. + * + * Given a configuration that is defined as + * ``` + * "type": "solver::Gmres", + * "krylov_dim": 20, + * "criteria": [ + * {"type": "Iteration", "max_iters": 10}, + * {"type": "ResidualNorm", "reduction_factor": 1e-6} + * ] + * ``` + * then passing it to this function like this: + * ```c++ + * auto gmres_factory = parse(config, context); + * ``` + * will create a factory for a GMRES solver, with the parameters `krylov_dim` + * set to 20, and a combined stopping criteria, consisting of an Iteration + * criteria with maximal 10 iterations, and a ResidualNorm criteria with a + * reduction factor of 1e-6. + * + * By default, the factory will use the value type double, and index type + * int32 when creating templated types. This can be changed by passing in a + * type_descriptor. For example: + * ```c++ + * auto gmres_factory = + * parse(config, context, + * make_type_descriptor()); + * ``` + * will lead to a GMRES solver that uses `float` as its value type. + * Additionally, the config can be used to set these types through the fields: + * ``` + * value_type: "some_value_type" + * ``` + * These types take precedence over the type descriptor and they are used for + * every created object beginning from the config level they are defined on and + * every deeper nested level, until a new type is defined. So, considering the + * following example + * ``` + * type: "solver::Ir", + * value_type: "float32" + * solver: { + * type: "solver::Gmres", + * preconditioner: { + * type: "preconditioner::Jacobi" + * value_type: "float64" + * } + * } + * ``` + * both the Ir and Gmres are using `float32` as a value type, and the + * Jacobi uses `float64`. + * + * @param config The property tree which must include `type` for the class + * base. + * @param context The registry which stores the building function map and the + * storage for generated objects. + * @param type_descriptor The default value and index type. If any object that + * is created as part of this configuration has a + * templated type, then the value and/or index type from + * the descriptor will be used. Any definition of the + * value and/or index type within the config will take + * precedence over the descriptor. If `void` is used for + * one or both of the types, then the corresponding type + * has to be defined in the config, otherwise the + * parsing will fail. + * + * @return a deferred_factory_parameter which creates an LinOpFactory after + * `.on(exec)` is called on it. + */ +deferred_factory_parameter parse( + const pnode& config, const registry& context, + const type_descriptor& td = make_type_descriptor<>()); + + +} // namespace config +} // namespace gko + + +#endif // GKO_PUBLIC_CORE_CONFIG_CONFIG_HPP_ diff --git a/include/ginkgo/core/config/property_tree.hpp b/include/ginkgo/core/config/property_tree.hpp index e1ef2f00dfb..2ddf42f5a27 100644 --- a/include/ginkgo/core/config/property_tree.hpp +++ b/include/ginkgo/core/config/property_tree.hpp @@ -25,7 +25,7 @@ namespace config { * A pnode can either be empty, hold a value (a string, integer, real, or bool), * contain an array of pnode., or contain a mapping between strings and pnodes. */ -class pnode { +class pnode final { public: using key_type = std::string; using map_type = std::map; diff --git a/include/ginkgo/core/config/registry.hpp b/include/ginkgo/core/config/registry.hpp new file mode 100644 index 00000000000..2efa160ee65 --- /dev/null +++ b/include/ginkgo/core/config/registry.hpp @@ -0,0 +1,260 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_PUBLIC_CORE_CONFIG_REGISTRY_HPP_ +#define GKO_PUBLIC_CORE_CONFIG_REGISTRY_HPP_ + + +#include +#include +#include +#include +#include +#include + + +#include +#include +#include +#include +#include +#include + + +namespace gko { +namespace config { + + +class registry; + +class type_descriptor; + +using configuration_map = + std::map( + const pnode&, const registry&, type_descriptor)>>; + + +namespace detail { + + +class registry_accessor; + + +/** + * base_type gives the base type according to given type. + * + * @tparam T the type + */ +template +struct base_type {}; + +template +struct base_type::value>> { + using type = LinOp; +}; + +template +struct base_type< + T, std::enable_if_t::value>> { + using type = LinOpFactory; +}; + +template +struct base_type< + T, + std::enable_if_t::value>> { + using type = stop::CriterionFactory; +}; + + +/** + * allowed_ptr is a type-erased object for LinOp/LinOpFactory/CriterionFactory + * shared_ptr. + */ +class allowed_ptr { +public: + /** + * The constructor accepts any shared pointer whose base type is LinOp, + * LinOpFactory, or CriterionFactory. We use a template rather than + * a constructor without a template because it allows the user to directly + * use uninitialized_list in the registry constructor without wrapping + * allowed_ptr manually. + */ + template + allowed_ptr(std::shared_ptr obj); + + /** + * Check whether it contains the Type data + * + * @tparam Type the checking type + */ + template + bool contains() const; + + /** + * Get the shared pointer with Type + * + * @tparam Type the desired type + * + * @return the shared pointer of Type + */ + template + std::shared_ptr get() const; + +private: + struct generic_container { + virtual ~generic_container() = default; + }; + + template + struct concrete_container : generic_container { + concrete_container(std::shared_ptr obj) : ptr{obj} + { + static_assert( + std::is_same::type>::value, + "The given type must be a base_type"); + } + + std::shared_ptr ptr; + }; + + std::shared_ptr data_; +}; + + +template +inline allowed_ptr::allowed_ptr(std::shared_ptr obj) +{ + data_ = + std::make_shared::type>>( + obj); +} + + +template +inline bool allowed_ptr::contains() const +{ + return dynamic_cast*>(data_.get()); +} + + +template +inline std::shared_ptr allowed_ptr::get() const +{ + GKO_THROW_IF_INVALID(this->template contains(), + "does not hold the requested type."); + return dynamic_cast*>(data_.get())->ptr; +} + + +} // namespace detail + + +/** + * This class stores additional context for creating Ginkgo objects from + * configuration files. + * + * The context can contain user-provided objects that derive from the following + * base types: + * - LinOp + * - LinOpFactory + * - CriterionFactory + * + * Additionally, users can provide mappings from a configuration (provided as + * a pnode) to user-defined types that are derived from LinOpFactory + */ +class registry final { +public: + friend class detail::registry_accessor; + + + /** + * registry constructor + * + * @param additional_map the additional map to dispatch the class base. + * Users can extend the map to fit their own + * LinOpFactory. We suggest using "usr::" as the + * prefix in the key to simply avoid conflict with + * ginkgo's map. + */ + registry(const configuration_map& additional_map = {}); + + /** + * registry constructor + * + * @param stored_map the map stores the shared pointer of users' objects. + * It can hold any type whose base type is + * LinOp/LinOpFactory/CriterionFactory. + * For example, + * ``` + * {{ + * {"csr", csr_shared_ptr}, + * {"cg", cg_shared_ptr} + * }} + * ``` + * @param additional_map the additional map to dispatch the class base. + * Users can extend the map to fit their own + * LinOpFactory. We suggest using "usr::" as the + * prefix in the key to simply avoid conflict with + * ginkgo's map. + */ + registry( + const std::unordered_map& stored_map, + const configuration_map& additional_map = {}); + + /** + * Store the data with the key. + * + * @tparam T the type + * + * @param key the unique key string + * @param data the shared pointer of the object + */ + template + bool emplace(std::string key, std::shared_ptr data); + +protected: + /** + * Search the key on the corresponding map. + * + * @tparam T the type + * + * @param key the key string + * + * @return the shared pointer of the object + */ + template + std::shared_ptr get_data(std::string key) const; + + /** + * Get the stored build map + */ + const configuration_map& get_build_map() const { return build_map_; } + +private: + std::unordered_map stored_map_; + configuration_map build_map_; +}; + + +template +inline bool registry::emplace(std::string key, std::shared_ptr data) +{ + auto it = stored_map_.emplace(key, data); + return it.second; +} + + +template +inline std::shared_ptr registry::get_data(std::string key) const +{ + return gko::as(stored_map_.at(key) + .template get::type>()); +} + +} // namespace config +} // namespace gko + +#endif // GKO_PUBLIC_CORE_CONFIG_REGISTRY_HPP_ diff --git a/include/ginkgo/core/config/type_descriptor.hpp b/include/ginkgo/core/config/type_descriptor.hpp new file mode 100644 index 00000000000..48475f7f469 --- /dev/null +++ b/include/ginkgo/core/config/type_descriptor.hpp @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_PUBLIC_CORE_CONFIG_TYPE_DESCRIPTOR_HPP_ +#define GKO_PUBLIC_CORE_CONFIG_TYPE_DESCRIPTOR_HPP_ + + +#include + +namespace gko { +namespace config { + + +/** + * This class describes the value and index types to be used when building a + * Ginkgo type from a configuration file. + * + * A type_descriptor is passed in order to define the parse function defines + * which template parameters, in terms of value_type and/or index_type, the + * created object will have. For example, a CG solver created like this: + * ``` + * auto cg = parse(config, context, type_descriptor("float64", "int32")); + * ``` + * will have the value type `float64` and the index type `int32`. Any Ginkgo + * object that does not require one of these types will just ignore it. The + * value `void` can be used to specify that no default type is provided. In this + * case, the configuration has to provide the necessary template types. + * + * If the configuration specifies one of the fields (or both): + * ``` + * value_type: "some_value_type" + * index_type: "some_index_type" + * ``` + * these types will take precedence over the type_descriptor. + */ +class type_descriptor final { +public: + /** + * type_descriptor constructor. There is free function + * `make_type_descriptor` to create the object by template. + * + * @param value_typestr the value type string. "void" means no default. + * @param index_typestr the index type string. "void" means no default. + * + * @note there is no way to call the constructor with explicit template, so + * we create another free function to handle it. + */ + explicit type_descriptor(std::string value_typestr = "float64", + std::string index_typestr = "int32"); + + /** + * Get the value type string. + */ + const std::string& get_value_typestr() const; + + /** + * Get the index type string + */ + const std::string& get_index_typestr() const; + +private: + std::string value_typestr_; + std::string index_typestr_; +}; + + +/** + * A helper function to properly set up the descriptor + * from template type directly. + * + * @tparam ValueType the value type in descriptor + * @tparam IndexType the index type in descriptor + */ +template +type_descriptor make_type_descriptor(); + + +} // namespace config +} // namespace gko + +#endif // GKO_PUBLIC_CORE_CONFIG_TYPE_DESCRIPTOR_HPP_ diff --git a/include/ginkgo/core/solver/cg.hpp b/include/ginkgo/core/solver/cg.hpp index a56e543d5ca..9302f2297b3 100644 --- a/include/ginkgo/core/solver/cg.hpp +++ b/include/ginkgo/core/solver/cg.hpp @@ -14,6 +14,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -73,6 +76,23 @@ class Cg : public EnableLinOp>, GKO_ENABLE_LIN_OP_FACTORY(Cg, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); + /** + * Create the parameters from the property_tree. + * Because this is directly tied to the specific type. The value/index type + * settings are ignored and type_descriptor is for children objects. + * + * @param config the property tree for setting + * @param context the registry + * @param td_for_child the type descriptor for children objects. The + * default will directly from the specific type. + * + * @return parameters + */ + static parameters_type parse(const config::pnode& config, + const config::registry& context, + const config::type_descriptor& td_for_child = + config::make_type_descriptor()); + protected: void apply_impl(const LinOp* b, LinOp* x) const override; diff --git a/include/ginkgo/ginkgo.hpp b/include/ginkgo/ginkgo.hpp index 854fb8c03da..f835eba9c26 100644 --- a/include/ginkgo/ginkgo.hpp +++ b/include/ginkgo/ginkgo.hpp @@ -53,7 +53,10 @@ #include #include +#include #include +#include +#include #include #include