diff --git a/core/config/config.cpp b/core/config/config.cpp index 8ff232f7af5..8c97c0038ed 100644 --- a/core/config/config.cpp +++ b/core/config/config.cpp @@ -24,12 +24,12 @@ deferred_factory_parameter parse(const pnode& config, const registry& context, const type_descriptor& td) { - if (auto& obj = config.get("Type")) { + if (auto& obj = config.get("type")) { auto func = detail::registry_accessor::get_build_map(context).at( obj.get_string()); return func(config, context, td); } - GKO_INVALID_STATE("Should contain Type property"); + GKO_INVALID_STATE("Should contain type property"); } diff --git a/core/config/config_helper.cpp b/core/config/config_helper.cpp index 89f7589e1ee..588aef61219 100644 --- a/core/config/config_helper.cpp +++ b/core/config/config_helper.cpp @@ -23,10 +23,10 @@ type_descriptor update_type(const pnode& config, const type_descriptor& td) auto value_typestr = td.get_value_typestr(); auto index_typestr = td.get_index_typestr(); - if (auto& obj = config.get("ValueType")) { + if (auto& obj = config.get("value_type")) { value_typestr = obj.get_string(); } - if (auto& obj = config.get("IndexType")) { + if (auto& obj = config.get("index_type")) { index_typestr = obj.get_string(); } return type_descriptor{value_typestr, index_typestr}; diff --git a/core/config/config_helper.hpp b/core/config/config_helper.hpp index 75018a04c0e..4b4fb85080a 100644 --- a/core/config/config_helper.hpp +++ b/core/config/config_helper.hpp @@ -42,19 +42,21 @@ deferred_factory_parameter parse( const pnode& config, const registry& context, const type_descriptor& td = make_type_descriptor<>()); -/** - * This function updates the default type setting from current config. Any type - * that is not specified in the config will fall back to the type stored in the - * current type_descriptor. - */ -type_descriptor update_type(const pnode& config, const type_descriptor& td); /** * get_stored_obj searches the object pointer stored in the registry by string */ template inline std::shared_ptr get_stored_obj(const pnode& config, - const registry& context); + const registry& context) +{ + std::shared_ptr ptr; + using T_non_const = std::remove_const_t; + ptr = detail::registry_accessor::get_data(context, + config.get_string()); + GKO_THROW_IF_INVALID(ptr.get() != nullptr, "Do not get the stored data"); + return ptr; +} /** @@ -85,55 +87,6 @@ get_factory(const pnode& config, /** * get_factory_vector will gives a vector of factory by calling get_factory. */ -template -inline std::vector> get_factory_vector( - const pnode& config, const registry& context, const type_descriptor& td); - - -/** - * get_value gets the corresponding type value from config. - * - * This is specialization for integral type - */ -template -inline - typename std::enable_if::value, IndexType>::type - get_value(const pnode& config); - -/** - * get_value gets the corresponding type value from config. - * - * This is specialization for floating point type - */ -template -inline typename std::enable_if::value, - ValueType>::type -get_value(const pnode& config); - -/** - * get_value gets the corresponding type value from config. - * - * This is specialization for complex type - */ -template -inline typename std::enable_if::value, - ValueType>::type -get_value(const pnode& config); - - -template -inline std::shared_ptr get_stored_obj(const pnode& config, - const registry& context) -{ - std::shared_ptr ptr; - using T_non_const = std::remove_const_t; - ptr = detail::registry_accessor::get_data(context, - config.get_string()); - GKO_THROW_IF_INVALID(ptr.get() != nullptr, "Do not get the stored data"); - return ptr; -} - - template inline std::vector> get_factory_vector( const pnode& config, const registry& context, const type_descriptor& td) @@ -152,6 +105,11 @@ inline std::vector> get_factory_vector( } +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for integral type + */ template inline typename std::enable_if::value, IndexType>::type @@ -165,6 +123,12 @@ inline return static_cast(val); } + +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for floating point type + */ template inline typename std::enable_if::value, ValueType>::type @@ -179,6 +143,11 @@ get_value(const pnode& config) return static_cast(val); } +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for complex type + */ template inline typename std::enable_if::value, ValueType>::type diff --git a/core/config/dispatch.hpp b/core/config/dispatch.hpp index 25b7af425c0..c765150f72a 100644 --- a/core/config/dispatch.hpp +++ b/core/config/dispatch.hpp @@ -76,6 +76,10 @@ deferred_factory_parameter dispatch(const pnode& config, } /** + * This function is to dispatch the type from runtime parameter. + * The concrete class need to have static member function + * `parse(pnode, registry, type_descriptor)` + * * @param config the configuration * @param context the registry context * @param td the default type descriptor diff --git a/core/config/stop_config.cpp b/core/config/stop_config.cpp index fdfdaab8e0b..55670377b27 100644 --- a/core/config/stop_config.cpp +++ b/core/config/stop_config.cpp @@ -138,11 +138,11 @@ get_factory(const pnode& config, gko::stop::CriterionFactory>( const pnode&, const registry&, type_descriptor)>> criterion_map{ - {{"Time", configure_time}, - {"Iteration", configure_iter}, - {"ResidualNorm", configure_residual}, - {"ImplicitResidualNorm", configure_implicit_residual}}}; - return criterion_map.at(config.get("Type").get_string())(config, + {{"stop::Time", configure_time}, + {"stop::Iteration", configure_iter}, + {"stop::ResidualNorm", configure_residual}, + {"stop::ImplicitResidualNorm", configure_implicit_residual}}}; + return criterion_map.at(config.get("type").get_string())(config, context, td); } GKO_THROW_IF_INVALID(!ptr.is_empty(), "Parse get nullptr in the end"); diff --git a/core/config/type_descriptor_helper.hpp b/core/config/type_descriptor_helper.hpp index 261250ab9ed..1a4ca1ac613 100644 --- a/core/config/type_descriptor_helper.hpp +++ b/core/config/type_descriptor_helper.hpp @@ -42,11 +42,11 @@ struct type_string {}; } TYPE_STRING_OVERLOAD(void, "void"); -TYPE_STRING_OVERLOAD(double, "double"); -TYPE_STRING_OVERLOAD(float, "float"); -TYPE_STRING_OVERLOAD(std::complex, "complex"); -TYPE_STRING_OVERLOAD(std::complex, "complex"); -TYPE_STRING_OVERLOAD(int32, "int"); +TYPE_STRING_OVERLOAD(double, "float64"); +TYPE_STRING_OVERLOAD(float, "float32"); +TYPE_STRING_OVERLOAD(std::complex, "complex"); +TYPE_STRING_OVERLOAD(std::complex, "complex"); +TYPE_STRING_OVERLOAD(int32, "int32"); TYPE_STRING_OVERLOAD(int64, "int64"); diff --git a/core/test/config/config.cpp b/core/test/config/config.cpp index 815ca674770..a791667b8e2 100644 --- a/core/test/config/config.cpp +++ b/core/test/config/config.cpp @@ -38,7 +38,8 @@ class Config : public ::testing::Test { : exec(gko::ReferenceExecutor::create()), mtx(gko::initialize( {{2, -1.0, 0.0}, {-1.0, 2, -1.0}, {0.0, -1.0, 2}}, exec)), - stop_config({{"Type", pnode{"Iteration"}}, {"max_iters", pnode{1}}}) + stop_config( + {{"type", pnode{"stop::Iteration"}}, {"max_iters", pnode{1}}}) {} std::shared_ptr exec; @@ -51,7 +52,8 @@ TEST_F(Config, GenerateObjectWithoutDefault) { auto reg = registry(); - pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}}; + pnode p{ + {{"value_type", pnode{"float64"}}, {"criteria", this->stop_config}}}; auto obj = parse(p, reg).on(this->exec); ASSERT_NE(dynamic_cast::Factory*>(obj.get()), @@ -67,7 +69,7 @@ TEST_F(Config, GenerateObjectWithData) pnode p{{{"generated_preconditioner", pnode{"precond"}}, {"criteria", this->stop_config}}}; auto obj = - parse(p, reg, type_descriptor{"float", "void"}) + parse(p, reg, type_descriptor{"float32", "void"}) .on(this->exec); ASSERT_NE(dynamic_cast::Factory*>(obj.get()), @@ -83,8 +85,8 @@ TEST_F(Config, GenerateObjectWithPreconditioner) { auto reg = registry(); auto precond_node = - pnode{{{"Type", pnode{"solver::Cg"}}, {"criteria", this->stop_config}}}; - pnode p{{{"ValueType", pnode{"double"}}, + pnode{{{"type", pnode{"solver::Cg"}}, {"criteria", this->stop_config}}}; + pnode p{{{"value_type", pnode{"float64"}}, {"criteria", this->stop_config}, {"preconditioner", precond_node}}}; @@ -102,7 +104,6 @@ TEST_F(Config, GenerateObjectWithPreconditioner) TEST_F(Config, GenerateObjectWithCustomBuild) { configuration_map config_map; - config_map["Custom"] = [](const pnode& config, const registry& context, const type_descriptor& td_for_child) { return gko::solver::Bicg::build().with_criteria( @@ -110,13 +111,13 @@ TEST_F(Config, GenerateObjectWithCustomBuild) }; auto reg = registry(config_map); auto precond_node = - pnode{std::map{{"Type", pnode{"Custom"}}}}; - pnode p{{{"ValueType", pnode{"double"}}, + pnode{std::map{{"type", pnode{"Custom"}}}}; + pnode p{{{"value_type", pnode{"float64"}}, {"criteria", this->stop_config}, {"preconditioner", precond_node}}}; auto obj = - parse(p, reg, type_descriptor{"double", "void"}) + parse(p, reg, type_descriptor{"float64", "void"}) .on(this->exec); ASSERT_NE(dynamic_cast::Factory*>(obj.get()), diff --git a/core/test/config/type_descriptor.cpp b/core/test/config/type_descriptor.cpp index d45cb47e730..c7b24980d63 100644 --- a/core/test/config/type_descriptor.cpp +++ b/core/test/config/type_descriptor.cpp @@ -21,21 +21,21 @@ TEST(TypeDescriptor, TemplateCreate) SCOPED_TRACE("defaule template"); auto td = make_type_descriptor<>(); - ASSERT_EQ(td.get_value_typestr(), "double"); - ASSERT_EQ(td.get_index_typestr(), "int"); + ASSERT_EQ(td.get_value_typestr(), "float64"); + ASSERT_EQ(td.get_index_typestr(), "int32"); } { SCOPED_TRACE("specify valuetype"); auto td = make_type_descriptor(); - ASSERT_EQ(td.get_value_typestr(), "float"); - ASSERT_EQ(td.get_index_typestr(), "int"); + ASSERT_EQ(td.get_value_typestr(), "float32"); + ASSERT_EQ(td.get_index_typestr(), "int32"); } { SCOPED_TRACE("specify all template"); auto td = make_type_descriptor, gko::int64>(); - ASSERT_EQ(td.get_value_typestr(), "complex"); + ASSERT_EQ(td.get_value_typestr(), "complex"); ASSERT_EQ(td.get_index_typestr(), "int64"); } { @@ -54,15 +54,15 @@ TEST(TypeDescriptor, Constructor) SCOPED_TRACE("defaule constructor"); type_descriptor td; - ASSERT_EQ(td.get_value_typestr(), "double"); - ASSERT_EQ(td.get_index_typestr(), "int"); + ASSERT_EQ(td.get_value_typestr(), "float64"); + ASSERT_EQ(td.get_index_typestr(), "int32"); } { SCOPED_TRACE("specify valuetype"); - type_descriptor td("float"); + type_descriptor td("float32"); - ASSERT_EQ(td.get_value_typestr(), "float"); - ASSERT_EQ(td.get_index_typestr(), "int"); + ASSERT_EQ(td.get_value_typestr(), "float32"); + ASSERT_EQ(td.get_index_typestr(), "int32"); } { SCOPED_TRACE("specify all parameters"); diff --git a/include/ginkgo/core/config/config.hpp b/include/ginkgo/core/config/config.hpp index 04d1528f637..1e7d12992c4 100644 --- a/include/ginkgo/core/config/config.hpp +++ b/include/ginkgo/core/config/config.hpp @@ -31,27 +31,53 @@ class pnode; * some file configuration. It reads a configuration stored as a property tree * and creates the desired type. * - * The configuration needs to specify the resulting type by the field: + * General rules for configuration + * 1. all parameter and template usage are according to the class directly. It + * has the same behavior as the class like default setting without specifying + * anything. When the class factory parameters allows `with_(value)`, + * the file configuration will allow `"": value` + * 2. all key will use snake_case including template. For example, ValueType -> + * value_type + * 3. If the value is not bool, integer, or floating point, we will use string + * to represent everything. For example, we will use string to select the + * enum value. `"baseline": "absolute"` will select the absolute baseline in + * ResidualNorm critrion + * 4. `"type"` is the new key to select the class without template type. We also + * prepand the namespace except for gko. For example, we use "solver::Cg" for + * Cg solver. Note. the template type is given by the another entry or from + * the type_descriptor. + * 5. the data type uses fixed-width representation + * void, int32, int64, float32, float64, complex, complex + * 6. We use [real, imag] to represent complex value. If it only contains one + * value or without [], we will treat it as the complex number with imaginary + * part = 0. + * 7. In many cases, the parameter allows vector input and we handle it by array + * of property tree. If the array only contains one object, users can + * directly provide the object without putting it into an array. + * `"criteria": [{...}]` and `"criteria": {...}` are the same. + * + * The configuration needs + * to specify the resulting type by the field: * ``` - * type: "some_supported_ginkgo_type" + * "type": "some_supported_ginkgo_type" * ``` - * The result will be a deferred_factory_parameter, which can be thought of as - * an intermediate step before a LinOpFactory. Providing the result an Executor - * through the function `.on(exec)` will then create the factory with the - * parameters as defined in the configuration. + * The result will be a deferred_factory_parameter, + * which can be thought of as an intermediate step before a LinOpFactory. + * Providing the result an Executor through the function `.on(exec)` will then + * create the factory with the parameters as defined in the configuration. * * Given a configuration that is defined as * ``` - * type: "solver::Gmres", - * krylov_dim: 20, - * stop: [ - * {iteration: 10}, - * {residual_norm: 1e-6} + * "type": "solver::Gmres", + * "krylov_dim": 20, + * "criteria": [ + * {"type": "stop::Iteration", "max_iters": 10}, + * {"type": "stop::ResidualNorm", "reduction_factor": 1e-6} * ] * ``` * then passing it to this function like this: * ```c++ - * auto gmres_factory = build_from_config(config, context); + * auto gmres_factory = parse(config, context); * ``` * will create a factory for a GMRES solver, with the parameters `krylov_dim` * set to 20, and a combined stopping criteria, consisting of an Iteration @@ -62,8 +88,9 @@ class pnode; * int32 when creating templated types. This can be changed by passing in a * type_descriptor. For example: * ```c++ - * auto gmres_factory = build_from_config(config, context, - * make_type_descriptor()); + * auto gmres_factory = + * build_from_config(config, context, + * make_type_descriptor()); * ``` * will lead to a GMRES solver that uses `float` as its value type. * Additionally, the config can be used to set these types through the fields: @@ -89,15 +116,16 @@ class pnode; * both the Ir and Gmres are using `float32` as a value type, and the * Jacobi uses `float64`. * - * @param config The property tree which must include `Type` for the class - * base. + * @param config The property tree which must include `type` for the class + * base. * @param context The registry which stores the building function map and the - * storage for generated objects. + * storage for generated objects. * @param type_descriptor The default value and index type. If any object that - * is created as part of this configuration has a templated type, then the value - * and/or index type from the descriptor will be used. Any definition of the - * value and/or index type within the config will take precedence over the - * descriptor. + * is created as part of this configuration has a + * templated type, then the value and/or index type from + * the descriptor will be used. Any definition of the + * value and/or index type within the config will take + * precedence over the descriptor. * * @return a deferred_factory_parameter which creates an LinOpFactory after * `.on(exec)` is called on it. diff --git a/include/ginkgo/core/config/property_tree.hpp b/include/ginkgo/core/config/property_tree.hpp index e1ef2f00dfb..2ddf42f5a27 100644 --- a/include/ginkgo/core/config/property_tree.hpp +++ b/include/ginkgo/core/config/property_tree.hpp @@ -25,7 +25,7 @@ namespace config { * A pnode can either be empty, hold a value (a string, integer, real, or bool), * contain an array of pnode., or contain a mapping between strings and pnodes. */ -class pnode { +class pnode final { public: using key_type = std::string; using map_type = std::map; diff --git a/include/ginkgo/core/config/registry.hpp b/include/ginkgo/core/config/registry.hpp index 8a4958bd64a..fc2b4151089 100644 --- a/include/ginkgo/core/config/registry.hpp +++ b/include/ginkgo/core/config/registry.hpp @@ -159,7 +159,7 @@ inline std::shared_ptr allowed_ptr::get() const * Additionally, users can provide mappings from a configuration (provided as * a pnode) to user-defined types that are derived from LinOpFactory */ -class registry { +class registry final { public: friend class detail::registry_accessor; diff --git a/include/ginkgo/core/config/type_descriptor.hpp b/include/ginkgo/core/config/type_descriptor.hpp index 7167a7650cd..3a05ee33aac 100644 --- a/include/ginkgo/core/config/type_descriptor.hpp +++ b/include/ginkgo/core/config/type_descriptor.hpp @@ -33,7 +33,7 @@ namespace config { * ``` * these types will take precedence over the type_descriptor. */ -class type_descriptor { +class type_descriptor final { public: /** * type_descriptor constructor. There is free function @@ -64,6 +64,13 @@ class type_descriptor { }; +/** + * make_type_descriptor is a helper function to properly set up the descriptor + * from template type directly. + * + * @tparam ValueType the value type in descriptor + * @tparam IndexType the index type in descriptor + */ template type_descriptor make_type_descriptor();