diff --git a/core/config/config_helper.cpp b/core/config/config_helper.cpp index 4423080866c..738c0a31539 100644 --- a/core/config/config_helper.cpp +++ b/core/config/config_helper.cpp @@ -19,8 +19,10 @@ namespace config { template <> -deferred_factory_parameter get_factory( - const pnode& config, const registry& context, const type_descriptor& td) +deferred_factory_parameter +build_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td) { deferred_factory_parameter ptr; if (config.get_tag() == pnode::tag_t::string) { @@ -31,7 +33,7 @@ deferred_factory_parameter get_factory( } else { GKO_INVALID_STATE("The data of config is not valid."); } - GKO_THROW_IF_INVALID(!ptr.is_empty(), "Parse get nullptr in the end"); + GKO_THROW_IF_INVALID(!ptr.is_empty(), "parse returned nullptr"); return ptr; } diff --git a/core/config/config_helper.hpp b/core/config/config_helper.hpp index 19aadda71e7..6efd6001a5a 100644 --- a/core/config/config_helper.hpp +++ b/core/config/config_helper.hpp @@ -60,45 +60,47 @@ inline std::shared_ptr get_stored_obj(const pnode& config, /** - * get_factory builds the factory from config (map) or searches the pointers in + * Build the factory from config (map) or search the pointers in * the registry by string. */ template -deferred_factory_parameter get_factory(const pnode& config, - const registry& context, - const type_descriptor& td); +deferred_factory_parameter build_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td); /** * specialize for const LinOpFactory */ template <> -deferred_factory_parameter get_factory( - const pnode& config, const registry& context, const type_descriptor& td); +deferred_factory_parameter +build_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td); /** * specialize for const stop::CriterionFactory */ template <> deferred_factory_parameter -get_factory(const pnode& config, - const registry& context, - const type_descriptor& td); +build_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td); /** - * get_factory_vector will gives a vector of factory by calling get_factory. + * give a vector of factory by calling build_or_get_factory. */ template -inline std::vector> get_factory_vector( +inline std::vector> build_or_get_factory_vector( const pnode& config, const registry& context, const type_descriptor& td) { std::vector> res; if (config.get_tag() == pnode::tag_t::array) { for (const auto& it : config.get_array()) { - res.push_back(get_factory(it, context, td)); + res.push_back(build_or_get_factory(it, context, td)); } } else { // only one config can be passed without array - res.push_back(get_factory(config, context, td)); + res.push_back(build_or_get_factory(config, context, td)); } return res; @@ -111,9 +113,8 @@ inline std::vector> get_factory_vector( * This is specialization for integral type */ template -inline - typename std::enable_if::value, IndexType>::type - get_value(const pnode& config) +inline std::enable_if_t::value, IndexType> +get_value(const pnode& config) { auto val = config.get_integer(); GKO_THROW_IF_INVALID( @@ -130,8 +131,7 @@ inline * This is specialization for floating point type */ template -inline typename std::enable_if::value, - ValueType>::type +inline std::enable_if::value, ValueType> get_value(const pnode& config) { auto val = config.get_real(); @@ -149,9 +149,8 @@ get_value(const pnode& config) * This is specialization for complex type */ template -inline typename std::enable_if::value, - ValueType>::type -get_value(const pnode& config) +inline std::enable_if::value, ValueType> get_value( + const pnode& config) { using real_type = gko::remove_complex; if (config.get_tag() == pnode::tag_t::real) { diff --git a/core/config/stop_config.cpp b/core/config/stop_config.cpp index 55670377b27..e3e2e7ad57b 100644 --- a/core/config/stop_config.cpp +++ b/core/config/stop_config.cpp @@ -124,9 +124,9 @@ configure_implicit_residual(const pnode& config, const registry& context, template <> deferred_factory_parameter -get_factory(const pnode& config, - const registry& context, - const type_descriptor& td) +build_or_get_factory(const pnode& config, + const registry& context, + const type_descriptor& td) { deferred_factory_parameter ptr; if (config.get_tag() == pnode::tag_t::string) { @@ -145,7 +145,7 @@ get_factory(const pnode& config, return criterion_map.at(config.get("type").get_string())(config, context, td); } - GKO_THROW_IF_INVALID(!ptr.is_empty(), "Parse get nullptr in the end"); + GKO_THROW_IF_INVALID(!ptr.is_empty(), "parse returned nullptr"); return ptr; } diff --git a/core/config/type_descriptor_helper.hpp b/core/config/type_descriptor_helper.hpp index 1a4ca1ac613..3917e317773 100644 --- a/core/config/type_descriptor_helper.hpp +++ b/core/config/type_descriptor_helper.hpp @@ -7,16 +7,11 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include namespace gko { @@ -49,6 +44,8 @@ TYPE_STRING_OVERLOAD(std::complex, "complex"); TYPE_STRING_OVERLOAD(int32, "int32"); TYPE_STRING_OVERLOAD(int64, "int64"); +#undef TYPE_STRING_OVERLOAD + } // namespace config } // namespace gko diff --git a/core/solver/cg.cpp b/core/solver/cg.cpp index 5452fe2bb7a..7c1e81c8971 100644 --- a/core/solver/cg.cpp +++ b/core/solver/cg.cpp @@ -69,12 +69,13 @@ typename Cg::parameters_type Cg::parse( } if (auto& obj = config.get("criteria")) { params.with_criteria( - gko::config::get_factory_vector( - obj, context, td_for_child)); + gko::config::build_or_get_factory_vector< + const stop::CriterionFactory>(obj, context, td_for_child)); } if (auto& obj = config.get("preconditioner")) { - params.with_preconditioner(gko::config::get_factory( - obj, context, td_for_child)); + params.with_preconditioner( + gko::config::build_or_get_factory( + obj, context, td_for_child)); } return params; } diff --git a/include/ginkgo/core/config/config.hpp b/include/ginkgo/core/config/config.hpp index 0d3d75b5de8..9fe2539d95f 100644 --- a/include/ginkgo/core/config/config.hpp +++ b/include/ginkgo/core/config/config.hpp @@ -46,8 +46,8 @@ class pnode; * prepend the namespace except for gko. For example, we use "solver::Cg" for * Cg solver. Note. the template type is given by the another entry or from * the type_descriptor. - * 5. the data type uses fixed-width representation - * int32, int64, float32, float64, complex, complex. + * 5. We have supports the following datatype with postfix to indicate their + * size: int32, int64, float32, float64, complex, complex. * note: we have also allow `void` additionally in type_descriptor to specify * file must contain the value/index type config. * 6. We use [real, imag] to represent complex values. If it only contains one diff --git a/include/ginkgo/core/config/registry.hpp b/include/ginkgo/core/config/registry.hpp index b4c1658c037..1cb7e40cf80 100644 --- a/include/ginkgo/core/config/registry.hpp +++ b/include/ginkgo/core/config/registry.hpp @@ -51,20 +51,20 @@ template struct base_type {}; template -struct base_type< - T, typename std::enable_if::value>::type> { +struct base_type::value>> { using type = LinOp; }; template -struct base_type::value>::type> { +struct base_type< + T, std::enable_if_t::value>> { using type = LinOpFactory; }; template -struct base_type::value>::type> { +struct base_type< + T, + std::enable_if_t::value>> { using type = stop::CriterionFactory; }; @@ -76,11 +76,11 @@ struct base_type allowed_ptr(std::shared_ptr obj); @@ -110,7 +110,12 @@ class allowed_ptr { template struct concrete_container : generic_container { - concrete_container(std::shared_ptr obj) : ptr{obj} {} + concrete_container(std::shared_ptr obj) : ptr{obj} + { + static_assert( + std::is_same::type>::value, + "The given type must be a base_type"); + } std::shared_ptr ptr; }; @@ -151,7 +156,8 @@ inline std::shared_ptr allowed_ptr::get() const * This class stores additional context for creating Ginkgo objects from * configuration files. * - * The context can contain user provided objects of the following types: + * The context can contain user-provided objects that derive from the following + * base types: * - LinOp * - LinOpFactory * - CriterionFactory @@ -168,7 +174,7 @@ class registry final { * registry constructor * * @param additional_map the additional map to dispatch the class base. - * Users can extend map to fit their own + * Users can extend the map to fit their own * LinOpFactory. */ registry(const configuration_map& additional_map = {}); @@ -187,7 +193,7 @@ class registry final { * }} * ``` * @param additional_map the additional map to dispatch the class base. - * Users can extend map to fit their own + * Users can extend the map to fit their own * LinOpFactory. */ registry( @@ -195,7 +201,7 @@ class registry final { const configuration_map& additional_map = {}); /** - * insert_data stores the data with the key. + * Store the data with the key. * * @tparam T the type * @@ -207,7 +213,7 @@ class registry final { protected: /** - * get_data searches the key on the corresponding map. + * Search the key on the corresponding map. * * @tparam T the type * @@ -219,7 +225,7 @@ class registry final { std::shared_ptr get_data(std::string key) const; /** - * get the stored build map + * Get the stored build map */ const configuration_map& get_build_map() const { return build_map_; } diff --git a/include/ginkgo/core/config/type_descriptor.hpp b/include/ginkgo/core/config/type_descriptor.hpp index 3a05ee33aac..244457e989e 100644 --- a/include/ginkgo/core/config/type_descriptor.hpp +++ b/include/ginkgo/core/config/type_descriptor.hpp @@ -26,7 +26,7 @@ namespace config { * object that does not require one of these types will just ignore it. We used * void type to specify no default type. * - * If the configurations specifies one of the fields (or both): + * If the configuration specifies one of the fields (or both): * ``` * value_type: "some_value_type" * index_type: "some_index_type" @@ -65,7 +65,7 @@ class type_descriptor final { /** - * make_type_descriptor is a helper function to properly set up the descriptor + * A helper function to properly set up the descriptor * from template type directly. * * @tparam ValueType the value type in descriptor