diff --git a/core/config/config.hpp b/core/config/config.hpp index 31b7def40b4..b13f4242c36 100644 --- a/core/config/config.hpp +++ b/core/config/config.hpp @@ -38,9 +38,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include +#include #include +#include #include #include @@ -131,9 +134,10 @@ get_pointer_vector( std::shared_ptr exec, type_descriptor td); -template ::value>::type> -inline IndexType get_value(const pnode& config) +template +inline + typename std::enable_if::value, IndexType>::type + get_value(const pnode& config) { auto val = config.get_data(); assert(val <= std::numeric_limits::max() && @@ -141,6 +145,32 @@ inline IndexType get_value(const pnode& config) return static_cast(val); } +template +inline typename std::enable_if::value, + ValueType>::type +get_value(const pnode& config) +{ + auto val = config.get_data(); + assert(val <= std::numeric_limits::max() && + val >= std::numeric_limits::min()); + return static_cast(val); +} + +template +inline typename std::enable_if::value, + ValueType>::type +get_value(const pnode& config) +{ + using real_type = gko::remove_complex; + if (config.is(pnode::status_t::object)) { + return static_cast(get_value(config)); + } else if (config.is(pnode::status_t::array)) { + return ValueType{get_value(config.at(0)), + get_value(config.at(1))}; + } + GKO_INVALID_STATE("Can not get complex value"); +} + #define SET_POINTER(_factory, _param_type, _param_name, _config, _context, \ _exec, _td) \ diff --git a/core/config/stop_config.cpp b/core/config/stop_config.cpp index 5d5774585f2..c92dbe0b5e7 100644 --- a/core/config/stop_config.cpp +++ b/core/config/stop_config.cpp @@ -90,8 +90,7 @@ class ResidualNormConfigurer { gko::config::type_descriptor td_for_child) { auto factory = stop::ResidualNorm::build(); - // SET_VALUE(factory, remove_complex, reduction_factor, - // config); + SET_VALUE(factory, remove_complex, reduction_factor, config); if (config.contains("baseline")) { factory.with_baseline( get_mode(config.at("baseline").get_data())); @@ -122,8 +121,7 @@ class ImplicitResidualNormConfigurer { gko::config::type_descriptor td_for_child) { auto factory = stop::ImplicitResidualNorm::build(); - // SET_VALUE(factory, remove_complex, reduction_factor, - // config); + SET_VALUE(factory, remove_complex, reduction_factor, config); if (config.contains("baseline")) { factory.with_baseline( get_mode(config.at("baseline").get_data())); diff --git a/core/test/config/config.cpp b/core/test/config/config.cpp index 3cd98d8af4b..6051a4812c2 100644 --- a/core/test/config/config.cpp +++ b/core/test/config/config.cpp @@ -48,16 +48,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/config/config.hpp" #include "core/test/config/utils.hpp" #include "core/test/utils.hpp" + namespace { -using gko::config::pnode; +using namespace gko::config; -template class Config : public ::testing::Test { protected: using value_type = double; @@ -74,38 +75,32 @@ class Config : public ::testing::Test { pnode stop_config; }; -TYPED_TEST_SUITE(Config, gko::test::ValueTypes, TypenameNameGenerator); +TEST_F(Config, GenerateMap) { ASSERT_NO_THROW(generate_config_map()); } -TYPED_TEST(Config, GenerateMap) -{ - ASSERT_NO_THROW(gko::config::generate_config_map()); -} - -TYPED_TEST(Config, GenerateObjectWithoutDefault) +TEST_F(Config, GenerateObjectWithoutDefault) { - auto config_map = gko::config::generate_config_map(); - auto reg = gko::config::registry(config_map); + auto config_map = generate_config_map(); + auto reg = registry(config_map); pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}}; - auto obj = gko::config::build_from_config<0>(p, reg, this->exec); + auto obj = build_from_config<0>(p, reg, this->exec); ASSERT_NE(dynamic_cast::Factory*>(obj.get()), nullptr); } -TYPED_TEST(Config, GenerateObjectWithData) +TEST_F(Config, GenerateObjectWithData) { - auto config_map = gko::config::generate_config_map(); - auto reg = gko::config::registry(config_map); + auto config_map = generate_config_map(); + auto reg = registry(config_map); reg.emplace("precond", this->mtx); pnode p{{{"generated_preconditioner", pnode{"precond"}}, {"criteria", this->stop_config}}}; - auto obj = - gko::config::build_from_config<0>(p, reg, this->exec, {"float", ""}); + auto obj = build_from_config<0>(p, reg, this->exec, {"float", ""}); ASSERT_NE(dynamic_cast::Factory*>(obj.get()), nullptr); @@ -116,15 +111,15 @@ TYPED_TEST(Config, GenerateObjectWithData) } -TYPED_TEST(Config, GenerateObjectWithPreconditioner) +TEST_F(Config, GenerateObjectWithPreconditioner) { - auto config_map = gko::config::generate_config_map(); - auto reg = gko::config::registry(config_map); + auto config_map = generate_config_map(); + auto reg = registry(config_map); pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}}; p.get_list()["preconditioner"] = pnode{{{"Type", pnode{"Cg"}}, {"criteria", this->stop_config}}}; - auto obj = gko::config::build_from_config<0>(p, reg, this->exec); + auto obj = build_from_config<0>(p, reg, this->exec); ASSERT_NE(dynamic_cast::Factory*>(obj.get()), nullptr); @@ -135,25 +130,23 @@ TYPED_TEST(Config, GenerateObjectWithPreconditioner) } -TYPED_TEST(Config, GenerateObjectWithCustomBuild) +TEST_F(Config, GenerateObjectWithCustomBuild) { - auto config_map = gko::config::generate_config_map(); + auto config_map = generate_config_map(); - config_map["Custom"] = [](const gko::config::pnode& config, - const gko::config::registry& context, + config_map["Custom"] = [](const pnode& config, const registry& context, std::shared_ptr& exec, - gko::config::type_descriptor td_for_child) { + type_descriptor td_for_child) { return gko::solver::Bicg::build() .with_criteria( gko::stop::Iteration::build().with_max_iters(2u).on(exec)) .on(exec); }; - auto reg = gko::config::registry(config_map); + auto reg = registry(config_map); pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}}; p.get_list()["preconditioner"] = pnode{{{"Type", pnode{"Custom"}}}}; - auto obj = - gko::config::build_from_config<0>(p, reg, this->exec, {"double", ""}); + auto obj = build_from_config<0>(p, reg, this->exec, {"double", ""}); ASSERT_NE(dynamic_cast::Factory*>(obj.get()), nullptr); @@ -165,4 +158,61 @@ TYPED_TEST(Config, GenerateObjectWithCustomBuild) } +TEST(GetValue, IndexType) +{ + long long int value = 123; + pnode config{value}; + + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(typeid(get_value(config)), typeid(int)); + ASSERT_EQ(typeid(get_value(config)), typeid(long)); + ASSERT_EQ(typeid(get_value(config)), typeid(unsigned)); + ASSERT_EQ(typeid(get_value(config)), typeid(long long int)); +} + + +TEST(GetValue, RealType) +{ + double value = 1.0; + pnode config{value}; + + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(get_value(config), value); + ASSERT_EQ(typeid(get_value(config)), typeid(float)); + ASSERT_EQ(typeid(get_value(config)), typeid(double)); +} + + +TEST(GetValue, ComplexType) +{ + double real = 1.0; + double imag = -1.0; + pnode config{real}; + pnode array_config; + array_config.get_array() = {pnode{real}, pnode{imag}}; + + // Only one value + ASSERT_EQ(get_value>(config), + std::complex(real)); + ASSERT_EQ(get_value>(config), + std::complex(real)); + ASSERT_EQ(typeid(get_value>(config)), + typeid(std::complex)); + ASSERT_EQ(typeid(get_value>(config)), + typeid(std::complex)); + // Two value [real, imag] + ASSERT_EQ(get_value>(array_config), + std::complex(real, imag)); + ASSERT_EQ(get_value>(array_config), + std::complex(real, imag)); + ASSERT_EQ(typeid(get_value>(array_config)), + typeid(std::complex)); + ASSERT_EQ(typeid(get_value>(array_config)), + typeid(std::complex)); +} + + } // namespace