Skip to content

Commit

Permalink
add get_value
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Aug 17, 2023
1 parent 877eb6d commit f04daec
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 36 deletions.
36 changes: 33 additions & 3 deletions core/config/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <string>
#include <type_traits>


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/stop/criterion.hpp>

Expand Down Expand Up @@ -131,16 +134,43 @@ get_pointer_vector<const stop::CriterionFactory>(
std::shared_ptr<const Executor> exec, type_descriptor td);


template <typename IndexType, typename = typename std::enable_if<
std::is_integral<IndexType>::value>::type>
inline IndexType get_value(const pnode& config)
template <typename IndexType>
inline
typename std::enable_if<std::is_integral<IndexType>::value, IndexType>::type
get_value(const pnode& config)
{
auto val = config.get_data<long long int>();
assert(val <= std::numeric_limits<IndexType>::max() &&
val >= std::numeric_limits<IndexType>::min());
return static_cast<IndexType>(val);
}

template <typename ValueType>
inline typename std::enable_if<std::is_floating_point<ValueType>::value,
ValueType>::type
get_value(const pnode& config)
{
auto val = config.get_data<double>();
assert(val <= std::numeric_limits<ValueType>::max() &&
val >= std::numeric_limits<ValueType>::min());
return static_cast<ValueType>(val);
}

template <typename ValueType>
inline typename std::enable_if<gko::is_complex_s<ValueType>::value,
ValueType>::type
get_value(const pnode& config)
{
using real_type = gko::remove_complex<ValueType>;
if (config.is(pnode::status_t::object)) {
return static_cast<ValueType>(get_value<real_type>(config));
} else if (config.is(pnode::status_t::array)) {
return ValueType{get_value<real_type>(config.at(0)),
get_value<real_type>(config.at(1))};
}
GKO_INVALID_STATE("Can not get complex value");
}


#define SET_POINTER(_factory, _param_type, _param_name, _config, _context, \
_exec, _td) \
Expand Down
6 changes: 2 additions & 4 deletions core/config/stop_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ class ResidualNormConfigurer {
gko::config::type_descriptor td_for_child)
{
auto factory = stop::ResidualNorm<ValueType>::build();
// SET_VALUE(factory, remove_complex<ValueType>, reduction_factor,
// config);
SET_VALUE(factory, remove_complex<ValueType>, reduction_factor, config);
if (config.contains("baseline")) {
factory.with_baseline(
get_mode(config.at("baseline").get_data<std::string>()));
Expand Down Expand Up @@ -122,8 +121,7 @@ class ImplicitResidualNormConfigurer {
gko::config::type_descriptor td_for_child)
{
auto factory = stop::ImplicitResidualNorm<ValueType>::build();
// SET_VALUE(factory, remove_complex<ValueType>, reduction_factor,
// config);
SET_VALUE(factory, remove_complex<ValueType>, reduction_factor, config);
if (config.contains("baseline")) {
factory.with_baseline(
get_mode(config.at("baseline").get_data<std::string>()));
Expand Down
108 changes: 79 additions & 29 deletions core/test/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/stop/residual_norm.hpp>


#include "core/config/config.hpp"
#include "core/test/config/utils.hpp"
#include "core/test/utils.hpp"


namespace {


using gko::config::pnode;
using namespace gko::config;


template <typename T>
class Config : public ::testing::Test {
protected:
using value_type = double;
Expand All @@ -74,38 +75,32 @@ class Config : public ::testing::Test {
pnode stop_config;
};

TYPED_TEST_SUITE(Config, gko::test::ValueTypes, TypenameNameGenerator);

TEST_F(Config, GenerateMap) { ASSERT_NO_THROW(generate_config_map()); }

TYPED_TEST(Config, GenerateMap)
{
ASSERT_NO_THROW(gko::config::generate_config_map());
}


TYPED_TEST(Config, GenerateObjectWithoutDefault)
TEST_F(Config, GenerateObjectWithoutDefault)
{
auto config_map = gko::config::generate_config_map();
auto reg = gko::config::registry(config_map);
auto config_map = generate_config_map();
auto reg = registry(config_map);

pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}};
auto obj = gko::config::build_from_config<0>(p, reg, this->exec);
auto obj = build_from_config<0>(p, reg, this->exec);

ASSERT_NE(dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get()),
nullptr);
}


TYPED_TEST(Config, GenerateObjectWithData)
TEST_F(Config, GenerateObjectWithData)
{
auto config_map = gko::config::generate_config_map();
auto reg = gko::config::registry(config_map);
auto config_map = generate_config_map();
auto reg = registry(config_map);
reg.emplace("precond", this->mtx);

pnode p{{{"generated_preconditioner", pnode{"precond"}},
{"criteria", this->stop_config}}};
auto obj =
gko::config::build_from_config<0>(p, reg, this->exec, {"float", ""});
auto obj = build_from_config<0>(p, reg, this->exec, {"float", ""});

ASSERT_NE(dynamic_cast<gko::solver::Cg<float>::Factory*>(obj.get()),
nullptr);
Expand All @@ -116,15 +111,15 @@ TYPED_TEST(Config, GenerateObjectWithData)
}


TYPED_TEST(Config, GenerateObjectWithPreconditioner)
TEST_F(Config, GenerateObjectWithPreconditioner)
{
auto config_map = gko::config::generate_config_map();
auto reg = gko::config::registry(config_map);
auto config_map = generate_config_map();
auto reg = registry(config_map);

pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}};
p.get_list()["preconditioner"] =
pnode{{{"Type", pnode{"Cg"}}, {"criteria", this->stop_config}}};
auto obj = gko::config::build_from_config<0>(p, reg, this->exec);
auto obj = build_from_config<0>(p, reg, this->exec);

ASSERT_NE(dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get()),
nullptr);
Expand All @@ -135,25 +130,23 @@ TYPED_TEST(Config, GenerateObjectWithPreconditioner)
}


TYPED_TEST(Config, GenerateObjectWithCustomBuild)
TEST_F(Config, GenerateObjectWithCustomBuild)
{
auto config_map = gko::config::generate_config_map();
auto config_map = generate_config_map();

config_map["Custom"] = [](const gko::config::pnode& config,
const gko::config::registry& context,
config_map["Custom"] = [](const pnode& config, const registry& context,
std::shared_ptr<const gko::Executor>& exec,
gko::config::type_descriptor td_for_child) {
type_descriptor td_for_child) {
return gko::solver::Bicg<double>::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(2u).on(exec))
.on(exec);
};
auto reg = gko::config::registry(config_map);
auto reg = registry(config_map);

pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}};
p.get_list()["preconditioner"] = pnode{{{"Type", pnode{"Custom"}}}};
auto obj =
gko::config::build_from_config<0>(p, reg, this->exec, {"double", ""});
auto obj = build_from_config<0>(p, reg, this->exec, {"double", ""});

ASSERT_NE(dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get()),
nullptr);
Expand All @@ -165,4 +158,61 @@ TYPED_TEST(Config, GenerateObjectWithCustomBuild)
}


TEST(GetValue, IndexType)
{
long long int value = 123;
pnode config{value};

ASSERT_EQ(get_value<int>(config), value);
ASSERT_EQ(get_value<long>(config), value);
ASSERT_EQ(get_value<unsigned>(config), value);
ASSERT_EQ(get_value<long long int>(config), value);
ASSERT_EQ(typeid(get_value<int>(config)), typeid(int));
ASSERT_EQ(typeid(get_value<long>(config)), typeid(long));
ASSERT_EQ(typeid(get_value<unsigned>(config)), typeid(unsigned));
ASSERT_EQ(typeid(get_value<long long int>(config)), typeid(long long int));
}


TEST(GetValue, RealType)
{
double value = 1.0;
pnode config{value};

ASSERT_EQ(get_value<float>(config), value);
ASSERT_EQ(get_value<double>(config), value);
ASSERT_EQ(typeid(get_value<float>(config)), typeid(float));
ASSERT_EQ(typeid(get_value<double>(config)), typeid(double));
}


TEST(GetValue, ComplexType)
{
double real = 1.0;
double imag = -1.0;
pnode config{real};
pnode array_config;
array_config.get_array() = {pnode{real}, pnode{imag}};

// Only one value
ASSERT_EQ(get_value<std::complex<float>>(config),
std::complex<float>(real));
ASSERT_EQ(get_value<std::complex<double>>(config),
std::complex<double>(real));
ASSERT_EQ(typeid(get_value<std::complex<float>>(config)),
typeid(std::complex<float>));
ASSERT_EQ(typeid(get_value<std::complex<double>>(config)),
typeid(std::complex<double>));
// Two value [real, imag]
ASSERT_EQ(get_value<std::complex<float>>(array_config),
std::complex<float>(real, imag));
ASSERT_EQ(get_value<std::complex<double>>(array_config),
std::complex<double>(real, imag));
ASSERT_EQ(typeid(get_value<std::complex<float>>(array_config)),
typeid(std::complex<float>));
ASSERT_EQ(typeid(get_value<std::complex<double>>(array_config)),
typeid(std::complex<double>));
}


} // namespace

0 comments on commit f04daec

Please sign in to comment.