Skip to content

Commit

Permalink
other solver except for multigrid
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Aug 21, 2023
1 parent fd5b59a commit ad2ee29
Show file tree
Hide file tree
Showing 6 changed files with 603 additions and 33 deletions.
20 changes: 15 additions & 5 deletions core/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,21 @@ namespace config {

buildfromconfig_map generate_config_map()
{
return {{"Cg", build_from_config<LinOpFactoryType::Cg>},
{"Bicg", build_from_config<LinOpFactoryType::Bicg>},
{"Bicgstab", build_from_config<LinOpFactoryType::Bicgstab>},
{"Cgs", build_from_config<LinOpFactoryType::Cgs>},
{"Fcg", build_from_config<LinOpFactoryType::Fcg>}};
return {
{"Cg", build_from_config<LinOpFactoryType::Cg>},
{"Bicg", build_from_config<LinOpFactoryType::Bicg>},
{"Bicgstab", build_from_config<LinOpFactoryType::Bicgstab>},
{"Cgs", build_from_config<LinOpFactoryType::Cgs>},
{"Fcg", build_from_config<LinOpFactoryType::Fcg>},
{"Ir", build_from_config<LinOpFactoryType::Ir>},
{"Idr", build_from_config<LinOpFactoryType::Idr>},
{"Gcr", build_from_config<LinOpFactoryType::Gcr>},
{"Gmres", build_from_config<LinOpFactoryType::Gmres>},
{"CbGmres", build_from_config<LinOpFactoryType::CbGmres>},
{"Direct", build_from_config<LinOpFactoryType::Direct>},
{"LowerTrs", build_from_config<LinOpFactoryType::LowerTrs>},
{"UpperTrs", build_from_config<LinOpFactoryType::UpperTrs>},
};
}


Expand Down
36 changes: 33 additions & 3 deletions core/config/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
#include <ginkgo/core/stop/criterion.hpp>


Expand Down Expand Up @@ -134,10 +136,19 @@ get_pointer_vector<const stop::CriterionFactory>(
std::shared_ptr<const Executor> exec, type_descriptor td);


template <typename ValueType>
inline typename std::enable_if<std::is_same<ValueType, bool>::value, bool>::type
get_value(const pnode& config)
{
auto val = config.get_data<bool>();
return val;
}

template <typename IndexType>
inline
typename std::enable_if<std::is_integral<IndexType>::value, IndexType>::type
get_value(const pnode& config)
inline typename std::enable_if<std::is_integral<IndexType>::value &&
!std::is_same<IndexType, bool>::value,
IndexType>::type
get_value(const pnode& config)
{
auto val = config.get_data<long long int>();
assert(val <= std::numeric_limits<IndexType>::max() &&
Expand Down Expand Up @@ -171,6 +182,23 @@ get_value(const pnode& config)
GKO_INVALID_STATE("Can not get complex value");
}

template <typename ValueType>
inline typename std::enable_if<
std::is_same<ValueType, solver::initial_guess_mode>::value,
solver::initial_guess_mode>::type
get_value(const pnode& config)
{
auto val = config.get_data<std::string>();
if (val == "zero") {
return solver::initial_guess_mode::zero;
} else if (val == "rhs") {
return solver::initial_guess_mode::rhs;
} else if (val == "provided") {
return solver::initial_guess_mode::provided;
}
GKO_INVALID_STATE("Wrong value for initial_guess_mode");
}


#define SET_POINTER(_factory, _param_type, _param_name, _config, _context, \
_exec, _td) \
Expand Down Expand Up @@ -228,6 +256,8 @@ TYPE_STRING_OVERLOAD(double, "double");
TYPE_STRING_OVERLOAD(float, "float");
TYPE_STRING_OVERLOAD(std::complex<double>, "complex<double>");
TYPE_STRING_OVERLOAD(std::complex<float>, "complex<float>");
TYPE_STRING_OVERLOAD(gko::int32, "int");
TYPE_STRING_OVERLOAD(gko::int64, "int64");


} // namespace config
Expand Down
2 changes: 2 additions & 0 deletions core/config/dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/synthesizer/containers.hpp>
Expand Down Expand Up @@ -111,6 +112,7 @@ std::unique_ptr<ReturnType> dispatch(std::string str, const pnode& config,
using value_type_list =
syn::type_list<double, float, std::complex<double>, std::complex<float>>;

using index_type_list = syn::type_list<gko::int32, gko::int64>;

} // namespace config
} // namespace gko
Expand Down
265 changes: 265 additions & 0 deletions core/config/solver_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,23 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/bicg.hpp>
#include <ginkgo/core/solver/bicgstab.hpp>
#include <ginkgo/core/solver/cb_gmres.hpp>
#include <ginkgo/core/solver/cg.hpp>
#include <ginkgo/core/solver/cgs.hpp>
#include <ginkgo/core/solver/direct.hpp>
#include <ginkgo/core/solver/fcg.hpp>
#include <ginkgo/core/solver/gcr.hpp>
#include <ginkgo/core/solver/gmres.hpp>
#include <ginkgo/core/solver/idr.hpp>
#include <ginkgo/core/solver/ir.hpp>
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/solver_config.hpp"


namespace gko {
namespace config {

Expand Down Expand Up @@ -126,5 +134,262 @@ std::unique_ptr<gko::LinOpFactory> build_from_config<LinOpFactoryType::Fcg>(
}


template <typename ValueType>
class IrConfigurator {
public:
static std::unique_ptr<typename solver::Ir<ValueType>::Factory>
build_from_config(const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
auto factory = solver::Ir<ValueType>::build();
SET_POINTER_VECTOR(factory, const stop::CriterionFactory, criteria,
config, context, exec, td_for_child);
SET_POINTER(factory, const LinOpFactory, solver, config, context, exec,
td_for_child);

SET_POINTER(factory, const LinOp, generated_solver, config, context,
exec, td_for_child);
SET_VALUE(factory, ValueType, relaxation_factor, config);
SET_VALUE(factory, solver::initial_guess_mode, default_initial_guess,
config);
return factory.on(exec);
}
};


template <>
std::unique_ptr<gko::LinOpFactory> build_from_config<LinOpFactoryType::Ir>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, IrConfigurator>(
updated.first, config, context, exec, updated, value_type_list());
}


template <typename ValueType>
class IdrConfigurator {
public:
static std::unique_ptr<typename solver::Idr<ValueType>::Factory>
build_from_config(const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
auto factory = solver::Idr<ValueType>::build();
common_solver_configure(factory, config, context, exec, td_for_child);
SET_VALUE(factory, size_type, subspace_dim, config);
SET_VALUE(factory, remove_complex<ValueType>, kappa, config);
SET_VALUE(factory, bool, deterministic, config);
SET_VALUE(factory, bool, complex_subspace, config);
return factory.on(exec);
}
};


template <>
std::unique_ptr<gko::LinOpFactory> build_from_config<LinOpFactoryType::Idr>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, IdrConfigurator>(
updated.first, config, context, exec, updated, value_type_list());
}


template <typename ValueType>
class GcrConfigurator {
public:
static std::unique_ptr<typename solver::Gcr<ValueType>::Factory>
build_from_config(const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
auto factory = solver::Gcr<ValueType>::build();
common_solver_configure(factory, config, context, exec, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
return factory.on(exec);
}
};


template <>
std::unique_ptr<gko::LinOpFactory> build_from_config<LinOpFactoryType::Gcr>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, GcrConfigurator>(
updated.first, config, context, exec, updated, value_type_list());
}


template <typename ValueType>
class GmresConfigurator {
public:
static std::unique_ptr<typename solver::Gmres<ValueType>::Factory>
build_from_config(const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
auto factory = solver::Gmres<ValueType>::build();
common_solver_configure(factory, config, context, exec, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
SET_VALUE(factory, bool, flexible, config);
return factory.on(exec);
}
};


template <>
std::unique_ptr<gko::LinOpFactory> build_from_config<LinOpFactoryType::Gmres>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, GmresConfigurator>(
updated.first, config, context, exec, updated, value_type_list());
}


template <typename ValueType>
class CbGmresConfigurator {
public:
static std::unique_ptr<typename solver::CbGmres<ValueType>::Factory>
build_from_config(const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
auto factory = solver::CbGmres<ValueType>::build();
common_solver_configure(factory, config, context, exec, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
if (config.contains("storage_precision")) {
auto get_storage_precision = [](std::string str) {
using gko::solver::cb_gmres::storage_precision;
if (str == "keep") {
return storage_precision::keep;
} else if (str == "reduce1") {
return storage_precision::reduce1;
} else if (str == "reduce2") {
return storage_precision::reduce2;
} else if (str == "integer") {
return storage_precision::integer;
} else if (str == "ireduce1") {
return storage_precision::ireduce1;
} else if (str == "ireduce2") {
return storage_precision::ireduce2;
}
GKO_INVALID_STATE("Wrong value for storage_precision");
};
factory.with_storage_precision(get_storage_precision(
config.at("storage_precision").get_data<std::string>()));
}
return factory.on(exec);
}
};


template <>
std::unique_ptr<gko::LinOpFactory> build_from_config<LinOpFactoryType::CbGmres>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, CbGmresConfigurator>(
updated.first, config, context, exec, updated, value_type_list());
}


template <typename ValueType, typename IndexType>
class DirectConfigurator {
public:
static std::unique_ptr<
typename experimental::solver::Direct<ValueType, IndexType>::Factory>
build_from_config(const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
auto factory =
experimental::solver::Direct<ValueType, IndexType>::build();
SET_VALUE(factory, size_type, num_rhs, config);
SET_POINTER(factory, const LinOpFactory, factorization, config, context,
exec, td_for_child);
return factory.on(exec);
}
};


template <>
std::unique_ptr<gko::LinOpFactory> build_from_config<LinOpFactoryType::Direct>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, DirectConfigurator>(
updated.first + "," + updated.second, config, context, exec, updated,
value_type_list(), index_type_list());
}


template <template <class, class> class Trs>
class trs_helper {
public:
template <typename ValueType, typename IndexType>
class configurator {
public:
static std::unique_ptr<typename Trs<ValueType, IndexType>::Factory>
build_from_config(const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
auto factory = Trs<ValueType, IndexType>::build();
SET_VALUE(factory, size_type, num_rhs, config);
SET_VALUE(factory, bool, unit_diagonal, config);
if (config.contains("algorithm")) {
using gko::solver::trisolve_algorithm;
auto str = config.at("algorithm").get_data<std::string>();
if (str == "sparselib") {
factory.with_algorithm(trisolve_algorithm::sparselib);
} else if (str == "syncfree") {
factory.with_algorithm(trisolve_algorithm::syncfree);
} else {
GKO_INVALID_STATE("Wrong value for algorithm");
}
}
return factory.on(exec);
}
};
};


template <>
std::unique_ptr<gko::LinOpFactory>
build_from_config<LinOpFactoryType::LowerTrs>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory,
trs_helper<solver::LowerTrs>::configurator>(
updated.first + "," + updated.second, config, context, exec, updated,
value_type_list(), index_type_list());
}

template <>
std::unique_ptr<gko::LinOpFactory>
build_from_config<LinOpFactoryType::UpperTrs>(
const pnode& config, const registry& context,
std::shared_ptr<const Executor>& exec, gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory,
trs_helper<solver::UpperTrs>::configurator>(
updated.first + "," + updated.second, config, context, exec, updated,
value_type_list(), index_type_list());
}


} // namespace config
} // namespace gko
Loading

0 comments on commit ad2ee29

Please sign in to comment.