Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factory config #1392

Merged
merged 25 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6340bd4
try
yhmtsai Aug 9, 2023
70f098e
change the order and put it into protected
yhmtsai Aug 10, 2023
284afed
add type descripitor
yhmtsai Aug 11, 2023
9d11041
add the general type dispatch
yhmtsai Aug 11, 2023
c272d6b
use property_tree as config type
yhmtsai Aug 15, 2023
6ebce2e
update the usage of property tree
yhmtsai Aug 15, 2023
bcc54a1
move the function into internal not protected
yhmtsai Aug 15, 2023
1f9386c
update stop config
yhmtsai Aug 16, 2023
0c25063
add registry test
yhmtsai Aug 16, 2023
2941cb2
throw gko::Error
yhmtsai Aug 17, 2023
d52036d
add get_value
yhmtsai Aug 17, 2023
8f66c85
update usage of property tree
yhmtsai Aug 22, 2023
88a41fa
use deferred_factory_parameter (generate param)
yhmtsai Oct 11, 2023
557c6d6
adapt with the explicit deferred type
yhmtsai Oct 11, 2023
bf30b5d
adapt the corresponding changes
yhmtsai Mar 28, 2024
a2cc99c
reduce the macro usage
yhmtsai Mar 28, 2024
b26951c
export less function in public and move def. to cpp
yhmtsai Mar 28, 2024
3db9203
move the def to source of registry
yhmtsai Mar 29, 2024
da4a31b
adapt the changes
yhmtsai Apr 14, 2024
92bf10a
add type_descriptor class
yhmtsai Apr 17, 2024
4535b58
split file, update doc and name, couple type_list and str
yhmtsai Apr 29, 2024
3b6ff47
combine map together
yhmtsai Apr 29, 2024
69de11a
rename, always contain gko map, move getter into protect
yhmtsai Apr 30, 2024
5412de5
add general intro, use snake_case, and fixed-width notation
yhmtsai May 8, 2024
2629167
update documentation, rename, complex number and rm stop::, index_type
yhmtsai May 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ target_sources(ginkgo
base/segmented_array.cpp
base/timer.cpp
base/version.cpp
config/config.cpp
config/config_helper.cpp
config/property_tree.cpp
config/registry.cpp
config/stop_config.cpp
config/type_descriptor.cpp
distributed/index_map.cpp
distributed/partition.cpp
factorization/cholesky.cpp
Expand Down
37 changes: 37 additions & 0 deletions core/config/config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/config/config.hpp>


#include <map>


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/solver/solver_base.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/registry_accessor.hpp"


namespace gko {
namespace config {


deferred_factory_parameter<gko::LinOpFactory> parse(const pnode& config,
const registry& context,
const type_descriptor& td)
{
if (auto& obj = config.get("type")) {
auto func = detail::registry_accessor::get_build_map(context).at(
obj.get_string());
return func(config, context, td);
}
GKO_INVALID_STATE("Should contain type property");
}


} // namespace config
} // namespace gko
65 changes: 65 additions & 0 deletions core/config/config_helper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <type_traits>


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/config/registry.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/registry_accessor.hpp"
#include "core/config/stop_config.hpp"

namespace gko {
namespace config {


template <>
deferred_factory_parameter<const LinOpFactory>
parse_or_get_factory<const LinOpFactory>(const pnode& config,
const registry& context,
const type_descriptor& td)
{
if (config.get_tag() == pnode::tag_t::string) {
return detail::registry_accessor::get_data<LinOpFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
return parse(config, context, td);
} else {
GKO_INVALID_STATE("The data of config is not valid.");
}
yhmtsai marked this conversation as resolved.
Show resolved Hide resolved
}


template <>
deferred_factory_parameter<const stop::CriterionFactory>
parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td)
{
if (config.get_tag() == pnode::tag_t::string) {
return detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
static std::map<std::string,
std::function<deferred_factory_parameter<
gko::stop::CriterionFactory>(
const pnode&, const registry&, type_descriptor)>>
criterion_map{
{{"Time", configure_time},
{"Iteration", configure_iter},
{"ResidualNorm", configure_residual},
{"ImplicitResidualNorm", configure_implicit_residual}}};
return criterion_map.at(config.get("type").get_string())(config,
context, td);
} else {
GKO_INVALID_STATE("The data of config is not valid.");
}
}

} // namespace config
} // namespace gko
180 changes: 180 additions & 0 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_CONFIG_HELPER_HPP_
#define GKO_CORE_CONFIG_CONFIG_HELPER_HPP_


#include <string>
#include <type_traits>


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
#include <ginkgo/core/stop/criterion.hpp>


#include "core/config/registry_accessor.hpp"


namespace gko {
namespace config {


/**
* LinOpFactoryType enum is to avoid forward declaration, linopfactory header,
* two template versions of parse
*/
enum class LinOpFactoryType : int { Cg = 0 };
MarcelKoch marked this conversation as resolved.
Show resolved Hide resolved


/**
* It is only an intermediate step after dispatching the class base type. Each
* implementation needs to deal with the template selection.
*/
template <LinOpFactoryType flag>
deferred_factory_parameter<gko::LinOpFactory> parse(
const pnode& config, const registry& context,
const type_descriptor& td = make_type_descriptor<>());
yhmtsai marked this conversation as resolved.
Show resolved Hide resolved


/**
* get_stored_obj searches the object pointer stored in the registry by string
*/
template <typename T>
inline std::shared_ptr<T> get_stored_obj(const pnode& config,
const registry& context)
{
std::shared_ptr<T> ptr;
using T_non_const = std::remove_const_t<T>;
ptr = detail::registry_accessor::get_data<T_non_const>(context,
config.get_string());
GKO_THROW_IF_INVALID(ptr.get() != nullptr, "Do not get the stored data");
return ptr;
}


/**
* Build the factory from config (map) or search the pointers in
* the registry by string.
*/
template <typename T>
deferred_factory_parameter<T> parse_or_get_factory(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* specialize for const LinOpFactory
*/
template <>
deferred_factory_parameter<const LinOpFactory>
parse_or_get_factory<const LinOpFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* specialize for const stop::CriterionFactory
*/
template <>
deferred_factory_parameter<const stop::CriterionFactory>
parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* give a vector of factory by calling parse_or_get_factory.
*/
template <typename T>
inline std::vector<deferred_factory_parameter<T>> parse_or_get_factory_vector(
const pnode& config, const registry& context, const type_descriptor& td)
{
std::vector<deferred_factory_parameter<T>> res;
if (config.get_tag() == pnode::tag_t::array) {
for (const auto& it : config.get_array()) {
res.push_back(parse_or_get_factory<T>(it, context, td));
}
} else {
// only one config can be passed without array
res.push_back(parse_or_get_factory<T>(config, context, td));
}

return res;
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for integral type
*/
template <typename IndexType>
inline std::enable_if_t<std::is_integral<IndexType>::value, IndexType>
get_value(const pnode& config)
{
auto val = config.get_integer();
GKO_THROW_IF_INVALID(
val <= std::numeric_limits<IndexType>::max() &&
val >= std::numeric_limits<IndexType>::min(),
"the config value is out of the range of the require type.");
return static_cast<IndexType>(val);
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for floating point type
*/
template <typename ValueType>
inline std::enable_if_t<std::is_floating_point<ValueType>::value, ValueType>
get_value(const pnode& config)
{
auto val = config.get_real();
// the max, min of floating point only consider positive value.
GKO_THROW_IF_INVALID(
val <= std::numeric_limits<ValueType>::max() &&
val >= -std::numeric_limits<ValueType>::max(),
"the config value is out of the range of the require type.");
return static_cast<ValueType>(val);
}

/**
* get_value gets the corresponding type value from config.
*
* This is specialization for complex type
*/
template <typename ValueType>
inline std::enable_if_t<gko::is_complex_s<ValueType>::value, ValueType>
get_value(const pnode& config)
{
using real_type = gko::remove_complex<ValueType>;
if (config.get_tag() == pnode::tag_t::real) {
return static_cast<ValueType>(get_value<real_type>(config));
} else if (config.get_tag() == pnode::tag_t::array) {
real_type real(0);
real_type imag(0);
if (config.get_array().size() >= 1) {
real = get_value<real_type>(config.get(0));
}
if (config.get_array().size() >= 2) {
imag = get_value<real_type>(config.get(1));
}
GKO_THROW_IF_INVALID(
config.get_array().size() <= 2,
"complex value array expression only accept up to two elements");
return ValueType{real, imag};
}
GKO_INVALID_STATE("Can not get complex value");
}


} // namespace config
} // namespace gko


#endif // GKO_CORE_CONFIG_CONFIG_HELPER_HPP_
Loading
Loading