Skip to content

Commit

Permalink
Move allocs to schwarz generate
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 6, 2024
1 parent d401e89 commit baa9619
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 37 deletions.
45 changes: 25 additions & 20 deletions core/distributed/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,11 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::apply_dense_impl(
if (this->coarse_solver_ != nullptr && this->galerkin_ops_ != nullptr) {
auto restrict = this->galerkin_ops_->get_restrict_op();
auto prolong = this->galerkin_ops_->get_prolong_op();
auto coarse =
as<experimental::distributed::Matrix<ValueType, LocalIndexType,
GlobalIndexType>>(
this->galerkin_ops_->get_coarse_op());
auto comm = coarse->get_communicator();

auto cs_ncols = dense_x->get_size()[1];
auto cs_local_nrows = coarse->get_local_matrix()->get_size()[0];
auto cs_global_nrows = coarse->get_size()[0];
auto cs_local_size = dim<2>(cs_local_nrows, cs_ncols);
auto cs_global_size = dim<2>(cs_global_nrows, cs_ncols);
auto csol = dist_vec::create(exec, comm, cs_global_size, cs_local_size,
dense_x->get_stride());
restrict->apply(dense_b, csol);
auto tmp = csol->clone();
this->coarse_solver_->apply(csol, tmp);
auto one = gko::initialize<Vector>({0.5}, exec);
auto zero = gko::initialize<Vector>({0.5}, exec);
prolong->apply(one, tmp, zero, dense_x);
restrict->apply(dense_b, this->csol_);
this->coarse_solver_->apply(this->csol_, this->csol_);
prolong->apply(this->half_.get(), this->csol_.get(), this->half_.get(),
dense_x);
}
}

Expand Down Expand Up @@ -111,6 +97,8 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(
std::shared_ptr<const LinOp> system_matrix)
{
using Vector = matrix::Dense<ValueType>;
using dist_vec = experimental::distributed::Vector<ValueType>;
if (parameters_.local_solver && parameters_.generated_local_solver) {
GKO_INVALID_STATE(
"Provided both a generated solver and a solver factory");
Expand All @@ -135,8 +123,25 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(
if (parameters_.galerkin_ops_factory && parameters_.coarse_solver_factory) {
this->galerkin_ops_ = as<multigrid::MultigridLevel>(
share(parameters_.galerkin_ops_factory->generate(dist_mat)));
this->coarse_solver_ = parameters_.coarse_solver_factory->generate(
this->galerkin_ops_->get_coarse_op());
auto coarse =
as<experimental::distributed::Matrix<ValueType, LocalIndexType,
GlobalIndexType>>(
this->galerkin_ops_->get_coarse_op());
auto exec = coarse->get_executor();
auto comm = coarse->get_communicator();
this->coarse_solver_ =
parameters_.coarse_solver_factory->generate(coarse);
// TODO: Set correct rhs and stride.
auto cs_ncols = 1; // dense_x->get_size()[1];
auto cs_local_nrows = coarse->get_local_matrix()->get_size()[0];
auto cs_global_nrows = coarse->get_size()[0];
auto cs_local_size = dim<2>(cs_local_nrows, cs_ncols);
auto cs_global_size = dim<2>(cs_global_nrows, cs_ncols);
this->csol_ = gko::share(dist_vec::create(exec, comm, cs_global_size,
cs_local_size,
1 /*dense_x->get_stride()*/));
// this->temp_ = this->csol->clone();
this->half_ = gko::share(gko::initialize<Vector>({0.5}, exec));
}
}

Expand Down
58 changes: 41 additions & 17 deletions examples/distributed-solver/distributed-solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ int main(int argc, char* argv[])
static_cast<gko::size_type>(argc >= 3 ? std::atoi(argv[2]) : 100);
const auto num_iters =
static_cast<gko::size_type>(argc >= 4 ? std::atoi(argv[3]) : 1000);
std::string schw_type = argc >= 5 ? argv[4] : "multi-level";

const std::map<std::string,
std::function<std::shared_ptr<gko::Executor>(MPI_Comm)>>
Expand Down Expand Up @@ -194,31 +195,53 @@ int main(int argc, char* argv[])
.with_criteria(
gko::stop::Iteration::build().with_max_iters(100).on(exec),
gko::stop::ResidualNorm<ValueType>::build()
.with_reduction_factor(1e-6)
.with_reduction_factor(1e-3)
.on(exec))
.on(exec));

auto pgm_fac = gko::share(pgm::build().on(exec));

// Setup the stopping criterion and logger
const gko::remove_complex<ValueType> reduction_factor{1e-8};
std::shared_ptr<const gko::log::Convergence<ValueType>> logger =
gko::log::Convergence<ValueType>::create();
auto Ainv =
solver::build()
.with_preconditioner(
schwarz::build()
.with_local_solver_factory(local_solver)
// .with_galerkin_ops_factory(pgm_fac)
// .with_coarse_solver_factory(coarse_solver)
.on(exec))
.with_criteria(
gko::stop::Iteration::build().with_max_iters(num_iters).on(
exec),
gko::stop::ResidualNorm<ValueType>::build()
.with_reduction_factor(reduction_factor)
.on(exec))
.on(exec)
->generate(A);
std::shared_ptr<gko::LinOp> Ainv{};
if (schw_type == "multi-level") {
Ainv =
solver::build()
.with_preconditioner(
schwarz::build()
.with_local_solver_factory(local_solver)
.with_galerkin_ops_factory(pgm_fac)
.with_coarse_solver_factory(coarse_solver)
.on(exec))
.with_criteria(
gko::stop::Iteration::build().with_max_iters(num_iters).on(
exec),
gko::stop::ResidualNorm<ValueType>::build()
.with_reduction_factor(reduction_factor)
.on(exec))
.on(exec)
->generate(A);
} else {
schw_type = "one-level";
Ainv =
solver::build()
.with_preconditioner(
schwarz::build()
.with_local_solver_factory(local_solver)
.with_galerkin_ops_factory(pgm_fac)
.with_coarse_solver_factory(coarse_solver)
.on(exec))
.with_criteria(
gko::stop::Iteration::build().with_max_iters(num_iters).on(
exec),
gko::stop::ResidualNorm<ValueType>::build()
.with_reduction_factor(reduction_factor)
.on(exec))
.on(exec)
->generate(A);
}
// Add logger to the generated solver to log the iteration count and
// residual norm
Ainv->add_logger(logger);
Expand All @@ -245,6 +268,7 @@ int main(int argc, char* argv[])
// clang-format off
std::cout << "\nNum rows in matrix: " << num_rows
<< "\nNum ranks: " << comm.size()
<< "\nPrecond type: " << schw_type
<< "\nFinal Res norm: " << *host_res->get_const_values()
<< "\nIteration count: " << logger->get_num_iterations()
<< "\nInit time: " << t_init_end - t_init
Expand Down
2 changes: 2 additions & 0 deletions include/ginkgo/core/distributed/preconditioner/schwarz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class Schwarz
std::shared_ptr<const LinOp> local_solver_;
std::shared_ptr<const multigrid::MultigridLevel> galerkin_ops_;
std::shared_ptr<const LinOp> coarse_solver_;
std::shared_ptr<LinOp> csol_;
std::shared_ptr<const LinOp> half_;
};


Expand Down

0 comments on commit baa9619

Please sign in to comment.