diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index 1b5f9237612..acbb4efac31 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -18,6 +18,7 @@ target_sources(ginkgo base/perturbation.cpp base/timer.cpp base/version.cpp + config/property_tree.cpp distributed/partition.cpp factorization/cholesky.cpp factorization/elimination_forest.cpp diff --git a/core/config/property_tree.cpp b/core/config/property_tree.cpp new file mode 100644 index 00000000000..47e627d21e6 --- /dev/null +++ b/core/config/property_tree.cpp @@ -0,0 +1,140 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +namespace gko { +namespace config { + + +pnode::pnode() : tag_(tag_t::empty) {} + + +pnode::pnode(bool boolean) : tag_(tag_t::boolean) +{ + union_data_.boolean_ = boolean; +} + + +pnode::pnode(const std::string& str) : tag_(tag_t::string) { str_ = str; } + + +pnode::pnode(double real) : tag_(tag_t::real) { union_data_.real_ = real; } + + +pnode::pnode(const char* str) : pnode(std::string(str)) {} + + +pnode::pnode(const array_type& array) : tag_(tag_t::array), array_(array) {} + + +pnode::pnode(const map_type& map) : tag_(tag_t::map), map_(map) {} + + +pnode::operator bool() const noexcept { return tag_ != tag_t::empty; } + + +pnode::tag_t pnode::get_tag() const { return tag_; } + +const pnode::array_type& pnode::get_array() const +{ + this->throw_if_not_contain(tag_t::array); + return array_; +} + + +const pnode::map_type& pnode::get_map() const +{ + this->throw_if_not_contain(tag_t::map); + return map_; +} + + +bool pnode::get_boolean() const +{ + this->throw_if_not_contain(tag_t::boolean); + return union_data_.boolean_; +} + + +std::int64_t pnode::get_integer() const +{ + this->throw_if_not_contain(tag_t::integer); + return union_data_.integer_; +} + + +double pnode::get_real() const +{ + this->throw_if_not_contain(tag_t::real); + return union_data_.real_; +} + + +const std::string& pnode::get_string() const +{ + this->throw_if_not_contain(tag_t::string); + return str_; +} + + +const pnode& pnode::get(const std::string& key) const +{ + this->throw_if_not_contain(tag_t::map); + auto it = map_.find(key); + if (it != map_.end()) { + return map_.at(key); + } else { + return pnode::empty_node(); + } +} + +const pnode& pnode::get(int index) const +{ + this->throw_if_not_contain(tag_t::array); + return array_.at(index); +} + + +void pnode::throw_if_not_contain(tag_t tag) const +{ + static auto str_tag = [](tag_t tag) -> std::string { + if (tag == tag_t::empty) { + return "empty"; + } else if (tag == tag_t::array) { + return "array"; + } else if (tag == tag_t::map) { + return "map"; + } else if (tag == tag_t::real) { + return "real"; + } else if (tag == tag_t::boolean) { + return "boolean"; + } else if (tag == tag_t::integer) { + return "integer"; + } else if (tag == tag_t::string) { + return "string"; + } else { + return "unknown"; + } + }; + bool is_valid = (tag_ == tag); + std::string msg = + "Contains " + str_tag(tag_) + ", but try to get " + str_tag(tag); + GKO_THROW_IF_INVALID(is_valid, msg); +} + + +const pnode& pnode::empty_node() +{ + static pnode empty_pnode{}; + return empty_pnode; +} + + +} // namespace config +} // namespace gko diff --git a/core/test/CMakeLists.txt b/core/test/CMakeLists.txt index 69f7ddd749e..e56001cbf4d 100644 --- a/core/test/CMakeLists.txt +++ b/core/test/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(gtest) add_subdirectory(accessor) add_subdirectory(base) add_subdirectory(components) +add_subdirectory(config) if(GINKGO_BUILD_MPI) add_subdirectory(mpi) endif() diff --git a/core/test/config/CMakeLists.txt b/core/test/config/CMakeLists.txt new file mode 100644 index 00000000000..e842152634c --- /dev/null +++ b/core/test/config/CMakeLists.txt @@ -0,0 +1 @@ +ginkgo_create_test(property_tree) diff --git a/core/test/config/property_tree.cpp b/core/test/config/property_tree.cpp new file mode 100644 index 00000000000..a552a6c08d8 --- /dev/null +++ b/core/test/config/property_tree.cpp @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include +#include +#include +#include + + +#include + + +#include + + +using namespace gko::config; + + +void assert_others_throw(const pnode& node) +{ + auto tag = node.get_tag(); + if (tag != pnode::tag_t::array) { + ASSERT_THROW(node.get_array(), gko::InvalidStateError); + ASSERT_THROW(node.get(0), gko::InvalidStateError); + } + if (tag != pnode::tag_t::map) { + ASSERT_THROW(node.get_map(), gko::InvalidStateError); + ASSERT_THROW(node.get("random"), gko::InvalidStateError); + } + if (tag != pnode::tag_t::boolean) { + ASSERT_THROW(node.get_boolean(), gko::InvalidStateError); + } + if (tag != pnode::tag_t::integer) { + ASSERT_THROW(node.get_integer(), gko::InvalidStateError); + } + if (tag != pnode::tag_t::real) { + ASSERT_THROW(node.get_real(), gko::InvalidStateError); + } + if (tag != pnode::tag_t::string) { + ASSERT_THROW(node.get_string(), gko::InvalidStateError); + } +} + + +TEST(PropertyTree, CreateEmpty) +{ + pnode root; + + ASSERT_EQ(root.get_tag(), pnode::tag_t::empty); + assert_others_throw(root); +} + + +TEST(PropertyTree, CreateStringData) +{ + pnode str(std::string("test_name")); + pnode char_str("test_name"); + + ASSERT_EQ(str.get_tag(), pnode::tag_t::string); + ASSERT_EQ(str.get_string(), "test_name"); + assert_others_throw(str); + ASSERT_EQ(char_str.get_tag(), pnode::tag_t::string); + ASSERT_EQ(char_str.get_string(), "test_name"); + assert_others_throw(char_str); +} + + +TEST(PropertyTree, CreateBoolData) +{ + pnode boolean(true); + + ASSERT_EQ(boolean.get_tag(), pnode::tag_t::boolean); + ASSERT_EQ(boolean.get_boolean(), true); + assert_others_throw(boolean); +} + + +TEST(PropertyTree, CreateIntegerData) +{ + pnode integer(1); + pnode integer_8(std::int8_t(1)); + pnode integer_16(std::int16_t(1)); + pnode integer_32(std::int32_t(1)); + pnode integer_64(std::int64_t(1)); + pnode integer_u8(std::uint8_t(1)); + pnode integer_u16(std::uint16_t(1)); + pnode integer_u32(std::uint32_t(1)); + pnode integer_u64(std::uint64_t(1)); + + + for (auto& node : {integer, integer_8, integer_16, integer_32, integer_64, + integer_u8, integer_u16, integer_u32, integer_u64}) { + ASSERT_EQ(node.get_tag(), pnode::tag_t::integer); + ASSERT_EQ(node.get_integer(), 1); + assert_others_throw(node); + } + ASSERT_THROW( + pnode(std::uint64_t(std::numeric_limits::max()) + 1), + std::runtime_error); +} + + +TEST(PropertyTree, CreateRealData) +{ + pnode real(1.0); + pnode real_float(float(1.0)); + pnode real_double(double(1.0)); + + for (auto& node : {real, real_double, real_float}) { + ASSERT_EQ(node.get_tag(), pnode::tag_t::real); + ASSERT_EQ(node.get_real(), 1.0); + assert_others_throw(node); + } +} + + +TEST(PropertyTree, CreateMap) +{ + pnode root({{"p0", pnode{1.0}}, + {"p1", pnode{1}}, + {"p2", pnode{pnode::map_type{{"p0", pnode{"test"}}}}}}); + + ASSERT_EQ(root.get_tag(), pnode::tag_t::map); + ASSERT_EQ(root.get("p0").get_real(), 1.0); + ASSERT_EQ(root.get("p1").get_integer(), 1); + ASSERT_EQ(root.get("p2").get_tag(), pnode::tag_t::map); + ASSERT_EQ(root.get("p2").get("p0").get_string(), "test"); + assert_others_throw(root); +} + + +TEST(PropertyTree, CreateArray) +{ + pnode root(pnode::array_type{pnode{"123"}, pnode{"456"}, pnode{"789"}}); + + ASSERT_EQ(root.get_tag(), pnode::tag_t::array); + ASSERT_EQ(root.get(0).get_string(), "123"); + ASSERT_EQ(root.get(1).get_string(), "456"); + ASSERT_EQ(root.get(2).get_string(), "789"); + ASSERT_THROW(root.get(3), std::out_of_range); + ASSERT_EQ(root.get_array().size(), 3); + assert_others_throw(root); +} + + +TEST(PropertyTree, ConversionToBool) +{ + pnode empty; + pnode non_empty{"test"}; + + ASSERT_FALSE(empty); + ASSERT_TRUE(non_empty); +} + + +TEST(PropertyTree, ReturnEmptyIfNotFound) +{ + pnode ptree(pnode::map_type{{"test", pnode{2}}}); + + auto obj = ptree.get("na"); + + ASSERT_EQ(obj.get_tag(), pnode::tag_t::empty); +} + + +TEST(PropertyTree, UseInCondition) +{ + pnode ptree(pnode::map_type{{"test", pnode{2}}}); + int first = 0; + int second = 0; + + if (auto obj = ptree.get("test")) { + first = static_cast(obj.get_integer()); + } + if (auto obj = ptree.get("na")) { + second = -1; + } else { + second = 1; + } + + ASSERT_EQ(first, 2); + ASSERT_EQ(second, 1); +} diff --git a/include/ginkgo/core/config/property_tree.hpp b/include/ginkgo/core/config/property_tree.hpp new file mode 100644 index 00000000000..e1ef2f00dfb --- /dev/null +++ b/include/ginkgo/core/config/property_tree.hpp @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_PUBLIC_CORE_CONFIG_PROPERTY_TREE_HPP_ +#define GKO_PUBLIC_CORE_CONFIG_PROPERTY_TREE_HPP_ + + +#include +#include +#include +#include +#include +#include +#include + + +namespace gko { +namespace config { + + +/** + * pnode describes a tree of properties. + * + * A pnode can either be empty, hold a value (a string, integer, real, or bool), + * contain an array of pnode., or contain a mapping between strings and pnodes. + */ +class pnode { +public: + using key_type = std::string; + using map_type = std::map; + using array_type = std::vector; + + /** + * tag_t is the indicator for the current node storage. + */ + enum class tag_t { empty, array, boolean, real, integer, string, map }; + + /** + * Default constructor: create an empty node + */ + explicit pnode(); + + /** + * Constructor for bool + * + * @param boolean the bool type value + */ + explicit pnode(bool boolean); + + /** + * Constructor for integer with all integer type + * + * @tparam T input type + * + * @param integer the integer type value + */ + template ::value>* = nullptr> + explicit pnode(T integer); + + /** + * Constructor for string + * + * @param str string type value + */ + explicit pnode(const std::string& str); + + /** + * Constructor for char* (otherwise, it will use bool) + * + * @param str the string like "..." + */ + explicit pnode(const char* str); + + /** + * Constructor for double (and also float) + * + * @param real the floating point type value + */ + explicit pnode(double real); + + /** + * Constructor for array + * + * @param array an pnode array + */ + explicit pnode(const array_type& array); + + /** + * Constructor for map + * + * @param map a (string, pnode)-map + */ + explicit pnode(const map_type& map); + + /** + * bool conversion. It's true if and only if it is not empty. + */ + explicit operator bool() const noexcept; + + /** + * Get the current node tag. + * + * @return the tag + */ + tag_t get_tag() const; + + /** + * Access the array stored in this property node. Throws + * `gko::InvalidStateError` if the property node does not store an array. + * + * @return the array + */ + const array_type& get_array() const; + + /** + * Access the map stored in this property node. Throws + * `gko::InvalidStateError` if the property node does not store a map. + * + * @return the map + */ + const map_type& get_map() const; + + /** + * Access the boolean value stored in this property node. Throws + * `gko::InvalidStateError` if the property node does not store a boolean + * value. + * + * @return the boolean value + */ + bool get_boolean() const; + + /** + * * Access the integer value stored in this property node. Throws + * `gko::InvalidStateError` if the property node does not store an integer + * value. + * + * @return the integer value + */ + std::int64_t get_integer() const; + + /** + * Access the real floating point value stored in this property node. Throws + * `gko::InvalidStateError` if the property node does not store a real value + * + * @return the real floating point value + */ + double get_real() const; + + /** + * Access the string stored in this property node. Throws + * `gko::InvalidStateError` if the property node does not store a string. + * + * @return the string + */ + const std::string& get_string() const; + + /** + * This function is to access the data under the map. It will throw error + * when it does not hold a map. When access non-existent key in the map, it + * will return an empty node. + * + * @param key the key for the node of the map + * + * @return node. If the map does not have the key, return + * an empty node. + */ + const pnode& get(const std::string& key) const; + + /** + * This function is to access the data under the array. It will throw error + * when it does not hold an array or access out-of-bound index. + * + * @param index the node index in array + * + * @return node. + */ + const pnode& get(int index) const; + +private: + void throw_if_not_contain(tag_t tag) const; + + static const pnode& empty_node(); + + tag_t tag_; + array_type array_; // for array + map_type map_; // for map + // value + std::string str_; + union { + std::int64_t integer_; + double real_; + bool boolean_; + } union_data_; +}; + + +template ::value>*> +pnode::pnode(T integer) : tag_(tag_t::integer) +{ + if (integer > std::numeric_limits::max() || + (std::is_signed::value && + integer < std::numeric_limits::min())) { + throw std::runtime_error("The input is out of the range of int64_t."); + } + union_data_.integer_ = static_cast(integer); +} + + +} // namespace config +} // namespace gko + +#endif // GKO_PUBLIC_CORE_CONFIG_PROPERTY_TREE_HPP_ diff --git a/include/ginkgo/ginkgo.hpp b/include/ginkgo/ginkgo.hpp index 57add89ac08..dede45cf6e6 100644 --- a/include/ginkgo/ginkgo.hpp +++ b/include/ginkgo/ginkgo.hpp @@ -52,6 +52,8 @@ #include #include +#include + #include #include #include diff --git a/test/test_install/test_install.cpp b/test/test_install/test_install.cpp index 0a63f129273..48252ef9bbe 100644 --- a/test/test_install/test_install.cpp +++ b/test/test_install/test_install.cpp @@ -305,6 +305,11 @@ int main() auto test = gko::version_info::get().header_version; } + // core/config.config.hpp + { + auto test = gko::config::pnode(); + } + // core/factorization/par_ilu.hpp { auto test = gko::factorization::ParIlu<>::build().on(exec);