Skip to content

Commit

Permalink
add general intro, use snake_case, and fixed-width notation
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
yhmtsai and MarcelKoch committed May 8, 2024
1 parent 5f65a31 commit a821795
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 113 deletions.
4 changes: 2 additions & 2 deletions core/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ deferred_factory_parameter<gko::LinOpFactory> parse(const pnode& config,
const registry& context,
const type_descriptor& td)
{
if (auto& obj = config.get("Type")) {
if (auto& obj = config.get("type")) {
auto func = detail::registry_accessor::get_build_map(context).at(
obj.get_string());
return func(config, context, td);
}
GKO_INVALID_STATE("Should contain Type property");
GKO_INVALID_STATE("Should contain type property");
}


Expand Down
4 changes: 2 additions & 2 deletions core/config/config_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ type_descriptor update_type(const pnode& config, const type_descriptor& td)
auto value_typestr = td.get_value_typestr();
auto index_typestr = td.get_index_typestr();

if (auto& obj = config.get("ValueType")) {
if (auto& obj = config.get("value_type")) {
value_typestr = obj.get_string();
}
if (auto& obj = config.get("IndexType")) {
if (auto& obj = config.get("index_type")) {
index_typestr = obj.get_string();
}
return type_descriptor{value_typestr, index_typestr};
Expand Down
81 changes: 25 additions & 56 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,21 @@ deferred_factory_parameter<gko::LinOpFactory> parse(
const pnode& config, const registry& context,
const type_descriptor& td = make_type_descriptor<>());

/**
* This function updates the default type setting from current config. Any type
* that is not specified in the config will fall back to the type stored in the
* current type_descriptor.
*/
type_descriptor update_type(const pnode& config, const type_descriptor& td);

/**
* get_stored_obj searches the object pointer stored in the registry by string
*/
template <typename T>
inline std::shared_ptr<T> get_stored_obj(const pnode& config,
const registry& context);
const registry& context)
{
std::shared_ptr<T> ptr;
using T_non_const = std::remove_const_t<T>;
ptr = detail::registry_accessor::get_data<T_non_const>(context,
config.get_string());
GKO_THROW_IF_INVALID(ptr.get() != nullptr, "Do not get the stored data");
return ptr;
}


/**
Expand Down Expand Up @@ -85,55 +87,6 @@ get_factory<const stop::CriterionFactory>(const pnode& config,
/**
* get_factory_vector will gives a vector of factory by calling get_factory.
*/
template <typename T>
inline std::vector<deferred_factory_parameter<T>> get_factory_vector(
const pnode& config, const registry& context, const type_descriptor& td);


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for integral type
*/
template <typename IndexType>
inline
typename std::enable_if<std::is_integral<IndexType>::value, IndexType>::type
get_value(const pnode& config);

/**
* get_value gets the corresponding type value from config.
*
* This is specialization for floating point type
*/
template <typename ValueType>
inline typename std::enable_if<std::is_floating_point<ValueType>::value,
ValueType>::type
get_value(const pnode& config);

/**
* get_value gets the corresponding type value from config.
*
* This is specialization for complex type
*/
template <typename ValueType>
inline typename std::enable_if<gko::is_complex_s<ValueType>::value,
ValueType>::type
get_value(const pnode& config);


template <typename T>
inline std::shared_ptr<T> get_stored_obj(const pnode& config,
const registry& context)
{
std::shared_ptr<T> ptr;
using T_non_const = std::remove_const_t<T>;
ptr = detail::registry_accessor::get_data<T_non_const>(context,
config.get_string());
GKO_THROW_IF_INVALID(ptr.get() != nullptr, "Do not get the stored data");
return ptr;
}


template <typename T>
inline std::vector<deferred_factory_parameter<T>> get_factory_vector(
const pnode& config, const registry& context, const type_descriptor& td)
Expand All @@ -152,6 +105,11 @@ inline std::vector<deferred_factory_parameter<T>> get_factory_vector(
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for integral type
*/
template <typename IndexType>
inline
typename std::enable_if<std::is_integral<IndexType>::value, IndexType>::type
Expand All @@ -165,6 +123,12 @@ inline
return static_cast<IndexType>(val);
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for floating point type
*/
template <typename ValueType>
inline typename std::enable_if<std::is_floating_point<ValueType>::value,
ValueType>::type
Expand All @@ -179,6 +143,11 @@ get_value(const pnode& config)
return static_cast<ValueType>(val);
}

/**
* get_value gets the corresponding type value from config.
*
* This is specialization for complex type
*/
template <typename ValueType>
inline typename std::enable_if<gko::is_complex_s<ValueType>::value,
ValueType>::type
Expand Down
4 changes: 4 additions & 0 deletions core/config/dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ deferred_factory_parameter<ReturnType> dispatch(const pnode& config,
}

/**
* This function is to dispatch the type from runtime parameter.
* The concrete class need to have static member function
* `parse(pnode, registry, type_descriptor)`
*
* @param config the configuration
* @param context the registry context
* @param td the default type descriptor
Expand Down
10 changes: 5 additions & 5 deletions core/config/stop_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ get_factory<const stop::CriterionFactory>(const pnode& config,
gko::stop::CriterionFactory>(
const pnode&, const registry&, type_descriptor)>>
criterion_map{
{{"Time", configure_time},
{"Iteration", configure_iter},
{"ResidualNorm", configure_residual},
{"ImplicitResidualNorm", configure_implicit_residual}}};
return criterion_map.at(config.get("Type").get_string())(config,
{{"stop::Time", configure_time},
{"stop::Iteration", configure_iter},
{"stop::ResidualNorm", configure_residual},
{"stop::ImplicitResidualNorm", configure_implicit_residual}}};
return criterion_map.at(config.get("type").get_string())(config,
context, td);
}
GKO_THROW_IF_INVALID(!ptr.is_empty(), "Parse get nullptr in the end");
Expand Down
10 changes: 5 additions & 5 deletions core/config/type_descriptor_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ struct type_string {};
}

TYPE_STRING_OVERLOAD(void, "void");
TYPE_STRING_OVERLOAD(double, "double");
TYPE_STRING_OVERLOAD(float, "float");
TYPE_STRING_OVERLOAD(std::complex<double>, "complex<double>");
TYPE_STRING_OVERLOAD(std::complex<float>, "complex<float>");
TYPE_STRING_OVERLOAD(int32, "int");
TYPE_STRING_OVERLOAD(double, "float64");
TYPE_STRING_OVERLOAD(float, "float32");
TYPE_STRING_OVERLOAD(std::complex<double>, "complex<float64>");
TYPE_STRING_OVERLOAD(std::complex<float>, "complex<float32>");
TYPE_STRING_OVERLOAD(int32, "int32");
TYPE_STRING_OVERLOAD(int64, "int64");


Expand Down
19 changes: 10 additions & 9 deletions core/test/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Config : public ::testing::Test {
: exec(gko::ReferenceExecutor::create()),
mtx(gko::initialize<Mtx>(
{{2, -1.0, 0.0}, {-1.0, 2, -1.0}, {0.0, -1.0, 2}}, exec)),
stop_config({{"Type", pnode{"Iteration"}}, {"max_iters", pnode{1}}})
stop_config(
{{"type", pnode{"stop::Iteration"}}, {"max_iters", pnode{1}}})
{}

std::shared_ptr<const gko::Executor> exec;
Expand All @@ -51,7 +52,8 @@ TEST_F(Config, GenerateObjectWithoutDefault)
{
auto reg = registry();

pnode p{{{"ValueType", pnode{"double"}}, {"criteria", this->stop_config}}};
pnode p{
{{"value_type", pnode{"float64"}}, {"criteria", this->stop_config}}};
auto obj = parse<LinOpFactoryType::Cg>(p, reg).on(this->exec);

ASSERT_NE(dynamic_cast<const gko::solver::Cg<double>::Factory*>(obj.get()),
Expand All @@ -67,7 +69,7 @@ TEST_F(Config, GenerateObjectWithData)
pnode p{{{"generated_preconditioner", pnode{"precond"}},
{"criteria", this->stop_config}}};
auto obj =
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"float", "void"})
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"float32", "void"})
.on(this->exec);

ASSERT_NE(dynamic_cast<gko::solver::Cg<float>::Factory*>(obj.get()),
Expand All @@ -83,8 +85,8 @@ TEST_F(Config, GenerateObjectWithPreconditioner)
{
auto reg = registry();
auto precond_node =
pnode{{{"Type", pnode{"solver::Cg"}}, {"criteria", this->stop_config}}};
pnode p{{{"ValueType", pnode{"double"}},
pnode{{{"type", pnode{"solver::Cg"}}, {"criteria", this->stop_config}}};
pnode p{{{"value_type", pnode{"float64"}},
{"criteria", this->stop_config},
{"preconditioner", precond_node}}};

Expand All @@ -102,21 +104,20 @@ TEST_F(Config, GenerateObjectWithPreconditioner)
TEST_F(Config, GenerateObjectWithCustomBuild)
{
configuration_map config_map;

config_map["Custom"] = [](const pnode& config, const registry& context,
const type_descriptor& td_for_child) {
return gko::solver::Bicg<double>::build().with_criteria(
gko::stop::Iteration::build().with_max_iters(2u));
};
auto reg = registry(config_map);
auto precond_node =
pnode{std::map<std::string, pnode>{{"Type", pnode{"Custom"}}}};
pnode p{{{"ValueType", pnode{"double"}},
pnode{std::map<std::string, pnode>{{"type", pnode{"Custom"}}}};
pnode p{{{"value_type", pnode{"float64"}},
{"criteria", this->stop_config},
{"preconditioner", precond_node}}};

auto obj =
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"double", "void"})
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"float64", "void"})
.on(this->exec);

ASSERT_NE(dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get()),
Expand Down
20 changes: 10 additions & 10 deletions core/test/config/type_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@ TEST(TypeDescriptor, TemplateCreate)
SCOPED_TRACE("defaule template");

Check warning on line 21 in core/test/config/type_descriptor.cpp

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"defaule" should be "default".
auto td = make_type_descriptor<>();

ASSERT_EQ(td.get_value_typestr(), "double");
ASSERT_EQ(td.get_index_typestr(), "int");
ASSERT_EQ(td.get_value_typestr(), "float64");
ASSERT_EQ(td.get_index_typestr(), "int32");
}
{
SCOPED_TRACE("specify valuetype");
auto td = make_type_descriptor<float>();

ASSERT_EQ(td.get_value_typestr(), "float");
ASSERT_EQ(td.get_index_typestr(), "int");
ASSERT_EQ(td.get_value_typestr(), "float32");
ASSERT_EQ(td.get_index_typestr(), "int32");
}
{
SCOPED_TRACE("specify all template");
auto td = make_type_descriptor<std::complex<float>, gko::int64>();

ASSERT_EQ(td.get_value_typestr(), "complex<float>");
ASSERT_EQ(td.get_value_typestr(), "complex<float32>");
ASSERT_EQ(td.get_index_typestr(), "int64");
}
{
Expand All @@ -54,15 +54,15 @@ TEST(TypeDescriptor, Constructor)
SCOPED_TRACE("defaule constructor");

Check warning on line 54 in core/test/config/type_descriptor.cpp

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"defaule" should be "default".
type_descriptor td;

ASSERT_EQ(td.get_value_typestr(), "double");
ASSERT_EQ(td.get_index_typestr(), "int");
ASSERT_EQ(td.get_value_typestr(), "float64");
ASSERT_EQ(td.get_index_typestr(), "int32");
}
{
SCOPED_TRACE("specify valuetype");
type_descriptor td("float");
type_descriptor td("float32");

ASSERT_EQ(td.get_value_typestr(), "float");
ASSERT_EQ(td.get_index_typestr(), "int");
ASSERT_EQ(td.get_value_typestr(), "float32");
ASSERT_EQ(td.get_index_typestr(), "int32");
}
{
SCOPED_TRACE("specify all parameters");
Expand Down
Loading

0 comments on commit a821795

Please sign in to comment.