Skip to content

Commit

Permalink
Rename orthog_method to ortho_method
Browse files Browse the repository at this point in the history
  • Loading branch information
nbeams committed Aug 8, 2024
1 parent fabb08c commit 332235c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 42 deletions.
35 changes: 17 additions & 18 deletions core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ GKO_REGISTER_OPERATION(multi_dot, gmres::multi_dot);
} // anonymous namespace


std::ostream& operator<<(std::ostream& stream, orthog_method orthog)
std::ostream& operator<<(std::ostream& stream, ortho_method ortho)
{
switch (orthog) {
case orthog_method::mgs:
switch (ortho) {
case ortho_method::mgs:
return stream << "mgs";
case orthog_method::cgs:
case ortho_method::cgs:
return stream << "cgs";
case orthog_method::cgs2:
case ortho_method::cgs2:
return stream << "cgs2";
}
return stream;
Expand All @@ -69,19 +69,19 @@ typename Gmres<ValueType>::parameters_type Gmres<ValueType>::parse(
if (auto& obj = config.get("flexible")) {
params.with_flexible(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("orthog_method")) {
if (auto& obj = config.get("ortho_method")) {
auto str = obj.get_string();
gmres::orthog_method orthog;
gmres::ortho_method ortho;
if (str == "mgs") {
orthog = gmres::orthog_method::mgs;
ortho = gmres::ortho_method::mgs;
} else if (str == "cgs") {
orthog = gmres::orthog_method::cgs;
ortho = gmres::ortho_method::cgs;
} else if (str == "cgs2") {
orthog = gmres::orthog_method::cgs2;
ortho = gmres::ortho_method::cgs2;
} else {
GKO_INVALID_CONFIG_VALUE("orthog_method", str);
GKO_INVALID_CONFIG_VALUE("ortho_method", str);
}
params.with_orthog_method(orthog);
params.with_ortho_method(ortho);
}
return params;
}
Expand Down Expand Up @@ -361,7 +361,7 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
// iteration of data at a time, we store it in the "logical" layout
// from the start.
LocalVector* hessenberg_aux = nullptr;
if (this->parameters_.orthog_method == gmres::orthog_method::cgs2) {
if (this->parameters_.ortho_method == gmres::ortho_method::cgs2) {
hessenberg_aux = this->template create_workspace_op<LocalVector>(
ws::hessenberg_aux, dim<2>{(krylov_dim + 1), num_rhs});
}
Expand Down Expand Up @@ -528,17 +528,16 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
// next_krylov = A * preconditioned_krylov_vector
this->get_system_matrix()->apply(preconditioned_krylov_vector,
next_krylov);
if (this->parameters_.orthog_method == gmres::orthog_method::mgs) {
if (this->parameters_.ortho_method == gmres::ortho_method::mgs) {
orthogonalize_mgs(hessenberg_iter.get(), krylov_bases,
next_krylov.get(), reduction_tmp, restart_iter,
num_rows, num_rhs, local_num_rows);
} else if (this->parameters_.orthog_method ==
gmres::orthog_method::cgs) {
} else if (this->parameters_.ortho_method == gmres::ortho_method::cgs) {
orthogonalize_cgs(hessenberg_iter.get(), krylov_bases,
next_krylov.get(), restart_iter, num_rows,
num_rhs, local_num_rows);
} else if (this->parameters_.orthog_method ==
gmres::orthog_method::cgs2) {
} else if (this->parameters_.ortho_method ==
gmres::ortho_method::cgs2) {
orthogonalize_cgs2(hessenberg_iter.get(), krylov_bases,
next_krylov.get(), hessenberg_aux, one_op,
restart_iter, num_rows, num_rhs, local_num_rows);
Expand Down
6 changes: 3 additions & 3 deletions core/test/config/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ struct Gmres
param.with_krylov_dim(3u);
config_map["flexible"] = pnode{true};
param.with_flexible(true);
config_map["orthog_method"] = pnode{"cgs"};
param.with_orthog_method(gko::solver::gmres::orthog_method::cgs);
config_map["ortho_method"] = pnode{"cgs"};
param.with_ortho_method(gko::solver::gmres::ortho_method::cgs);
}

template <bool from_reg, typename AnswerType>
Expand All @@ -302,7 +302,7 @@ struct Gmres
solver_config_test::template validate<from_reg>(result, answer);
ASSERT_EQ(res_param.krylov_dim, ans_param.krylov_dim);
ASSERT_EQ(res_param.flexible, ans_param.flexible);
ASSERT_EQ(res_param.orthog_method, ans_param.orthog_method);
ASSERT_EQ(res_param.ortho_method, ans_param.ortho_method);
}
};

Expand Down
8 changes: 4 additions & 4 deletions include/ginkgo/core/solver/gmres.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace gmres {
/**
* Set the orthogonalization method for the Krylov subspace.
*/
enum class orthog_method {
enum class ortho_method {
/**
* Modified Gram-Schmidt (default)
*/
Expand All @@ -51,7 +51,7 @@ enum class orthog_method {
};

/** Prints an orthogonalization method. */
std::ostream& operator<<(std::ostream& stream, orthog_method orthog);
std::ostream& operator<<(std::ostream& stream, ortho_method ortho);

} // namespace gmres

Expand Down Expand Up @@ -118,8 +118,8 @@ class Gmres
bool GKO_FACTORY_PARAMETER_SCALAR(flexible, false);

/** Orthogonalization method */
gmres::orthog_method GKO_FACTORY_PARAMETER_SCALAR(
orthog_method, gmres::orthog_method::mgs);
gmres::ortho_method GKO_FACTORY_PARAMETER_SCALAR(
ortho_method, gmres::ortho_method::mgs);
};
GKO_ENABLE_LIN_OP_FACTORY(Gmres, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down
10 changes: 5 additions & 5 deletions reference/test/solver/gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,17 +754,17 @@ TYPED_TEST(Gmres, SolvesBigDenseSystem1WithRestart)

TYPED_TEST(Gmres, SolvesWithPreconditioner)
{
using gko::solver::gmres::orthog_method;
using gko::solver::gmres::ortho_method;

using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
using value_type = typename TestFixture::value_type;
for (auto orthog :
{orthog_method::mgs, orthog_method::cgs, orthog_method::cgs2}) {
SCOPED_TRACE(orthog);
for (auto ortho :
{ortho_method::mgs, ortho_method::cgs, ortho_method::cgs2}) {
SCOPED_TRACE(ortho);
auto gmres_factory_preconditioner =
Solver::build()
.with_orthog_method(orthog)
.with_ortho_method(ortho)
.with_criteria(
gko::stop::Iteration::build().with_max_iters(100u),
gko::stop::ResidualNorm<value_type>::build()
Expand Down
12 changes: 6 additions & 6 deletions test/mpi/solver/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,14 @@ struct Ir : SimpleSolverTest<gko::solver::Ir<solver_value_type>> {
};


template <unsigned dimension, gko::solver::gmres::orthog_method orthog>
template <unsigned dimension, gko::solver::gmres::ortho_method ortho>
struct Gmres : SimpleSolverTest<gko::solver::Gmres<solver_value_type>> {
static typename solver_type::parameters_type build(
std::shared_ptr<const gko::Executor> exec)
{
return SimpleSolverTest<gko::solver::Gmres<solver_value_type>>::build(
std::move(exec))
.with_orthog_method(orthog)
.with_ortho_method(ortho)
.with_krylov_dim(dimension);
}
};
Expand Down Expand Up @@ -532,10 +532,10 @@ class Solver : public CommonMpiTestFixture {

using SolverTypes =
::testing::Types<Cg, CgWithMg, Cgs, Fcg, Bicgstab, Ir, Gcr<10u>, Gcr<100u>,
Gmres<10u, gko::solver::gmres::orthog_method::mgs>,
Gmres<10u, gko::solver::gmres::orthog_method::cgs>,
Gmres<10u, gko::solver::gmres::orthog_method::cgs2>,
Gmres<100u, gko::solver::gmres::orthog_method::mgs>>;
Gmres<10u, gko::solver::gmres::ortho_method::mgs>,
Gmres<10u, gko::solver::gmres::ortho_method::cgs>,
Gmres<10u, gko::solver::gmres::ortho_method::cgs2>,
Gmres<100u, gko::solver::gmres::ortho_method::mgs>>;

TYPED_TEST_SUITE(Solver, SolverTypes, TypenameNameGenerator);

Expand Down
12 changes: 6 additions & 6 deletions test/solver/gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,18 +327,18 @@ TEST_F(Gmres, GmresApplyOneRHSIsEquivalentToRef)

TEST_F(Gmres, GmresApplyMultipleRHSIsEquivalentToRef)
{
using gko::solver::gmres::orthog_method;
using gko::solver::gmres::ortho_method;
auto base_params = gko::clone(ref, ref_gmres_factory)->get_parameters();

for (auto orthog :
{orthog_method::mgs, orthog_method::cgs, orthog_method::cgs2}) {
SCOPED_TRACE(orthog);
for (auto ortho :
{ortho_method::mgs, ortho_method::cgs, ortho_method::cgs2}) {
SCOPED_TRACE(ortho);
int m = 123;
int n = 5;
auto ref_solver =
base_params.with_orthog_method(orthog).on(ref)->generate(mtx);
base_params.with_ortho_method(ortho).on(ref)->generate(mtx);
auto exec_solver =
base_params.with_orthog_method(orthog).on(exec)->generate(d_mtx);
base_params.with_ortho_method(ortho).on(exec)->generate(d_mtx);
auto b = gen_mtx(m, n);
auto x = gen_mtx(m, n);
auto d_b = gko::clone(exec, b);
Expand Down

0 comments on commit 332235c

Please sign in to comment.