Skip to content

Commit

Permalink
use iteration from stop criterion and update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Aug 3, 2023
1 parent 197acd4 commit 7df1d42
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 70 deletions.
119 changes: 90 additions & 29 deletions core/solver/chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,37 @@ GKO_REGISTER_OPERATION(initialize, ir::initialize);
} // namespace chebyshev


template <typename ValueType>
Chebyshev<ValueType>::Chebyshev(const Factory* factory,
std::shared_ptr<const LinOp> system_matrix)
: EnableLinOp<Chebyshev>(factory->get_executor(),
gko::transpose(system_matrix->get_size())),
EnableSolverBase<Chebyshev>{std::move(system_matrix)},
EnableIterativeBase<Chebyshev>{
stop::combine(factory->get_parameters().criteria)},
parameters_{factory->get_parameters()}
{
if (parameters_.generated_solver) {
this->set_solver(parameters_.generated_solver);
} else if (parameters_.solver) {
this->set_solver(
parameters_.solver->generate(this->get_system_matrix()));
} else {
this->set_solver(matrix::Identity<ValueType>::create(
this->get_executor(), this->get_size()));
}
this->set_default_initial_guess(parameters_.default_initial_guess);
center_ = (std::get<0>(parameters_.foci) + std::get<1>(parameters_.foci)) /
ValueType{2};
foci_direction_ =
(std::get<1>(parameters_.foci) - std::get<0>(parameters_.foci)) /
ValueType{2};
// if changing the lower/upper eig, need to reset it to zero
num_generated_scalar_ = 0;
num_max_generation_ = 3;
}


template <typename ValueType>
void Chebyshev<ValueType>::set_solver(std::shared_ptr<const LinOp> new_solver)
{
Expand Down Expand Up @@ -185,12 +216,29 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
GKO_SOLVER_VECTOR(residual, dense_b);
GKO_SOLVER_VECTOR(inner_solution, dense_b);
GKO_SOLVER_VECTOR(update_solution, dense_b);

// Use the scalar first
auto num_keep = this->get_parameters().num_keep;
// get the iteration information from stopping criterion.
if (auto combined =
std::dynamic_pointer_cast<const gko::stop::Combined::Factory>(
this->get_stop_criterion_factory())) {
for (const auto& factory : combined->get_parameters().criteria) {
if (auto iter_stop = std::dynamic_pointer_cast<
const gko::stop::Iteration::Factory>(factory)) {
num_max_generation_ = std::max(
num_max_generation_, iter_stop->get_parameters().max_iters);
}
}
} else if (auto iter_stop = std::dynamic_pointer_cast<
const gko::stop::Iteration::Factory>(
this->get_stop_criterion_factory())) {
num_max_generation_ = std::max(num_max_generation_,
iter_stop->get_parameters().max_iters);
}
auto alpha = this->template create_workspace_scalar<ValueType>(
GKO_SOLVER_TRAITS::alpha, num_keep + 1);
GKO_SOLVER_TRAITS::alpha, num_max_generation_ + 1);
auto beta = this->template create_workspace_scalar<ValueType>(
GKO_SOLVER_TRAITS::beta, num_keep + 1);
GKO_SOLVER_TRAITS::beta, num_max_generation_ + 1);

GKO_SOLVER_ONE_MINUS_ONE();

Expand Down Expand Up @@ -218,39 +266,50 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
int iter = -1;
while (true) {
++iter;
this->template log<log::Logger::iteration_complete>(
this, iter, residual_ptr, dense_x);

if (iter == 0) {
// In iter 0, the iteration and residual are updated.
if (stop_criterion->update()
.num_iterations(iter)
.residual(residual_ptr)
.solution(dense_x)
.check(relative_stopping_id, true, &stop_status,
&one_changed)) {
bool all_stopped = stop_criterion->update()
.num_iterations(iter)
.residual(residual_ptr)
.solution(dense_x)
.check(relative_stopping_id, true,
&stop_status, &one_changed);
this->template log<log::Logger::iteration_complete>(
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
&stop_status, all_stopped);
if (all_stopped) {
break;
}
} else {
// In the other iterations, the residual can be updated separately.
if (stop_criterion->update()
.num_iterations(iter)
.solution(dense_x)
.check(relative_stopping_id, false, &stop_status,
&one_changed)) {
bool all_stopped = stop_criterion->update()
.num_iterations(iter)
.solution(dense_x)
// we have the residual check later
.ignore_residual_check(true)
.check(relative_stopping_id, false,
&stop_status, &one_changed);
if (all_stopped) {
this->template log<log::Logger::iteration_complete>(
this, dense_b, dense_x, iter, nullptr, nullptr, nullptr,
&stop_status, all_stopped);
break;
}
residual_ptr = residual;
// residual = b - A * x
residual->copy_from(dense_b);
this->get_system_matrix()->apply(neg_one_op, dense_x, one_op,
residual);
if (stop_criterion->update()
.num_iterations(iter)
.residual(residual_ptr)
.solution(dense_x)
.check(relative_stopping_id, true, &stop_status,
&one_changed)) {
all_stopped = stop_criterion->update()
.num_iterations(iter)
.residual(residual_ptr)
.solution(dense_x)
.check(relative_stopping_id, true, &stop_status,
&one_changed);
this->template log<log::Logger::iteration_complete>(
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
&stop_status, all_stopped);
if (all_stopped) {
break;
}
}
Expand All @@ -262,17 +321,18 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
inner_solution->copy_from(residual_ptr);
}
solver_->apply(residual_ptr, inner_solution);
size_type index = (iter >= num_keep) ? num_keep : iter;
size_type index =
(iter >= num_max_generation_) ? num_max_generation_ : iter;
auto alpha_scalar =
alpha->create_submatrix(span{0, 1}, span{index, index + 1});
auto beta_scalar =
beta->create_submatrix(span{0, 1}, span{index, index + 1});
if (iter == 0) {
if (num_generated_ < num_keep) {
if (num_generated_scalar_ < num_max_generation_) {
alpha_scalar->fill(alpha_ref);
// unused beta for first iteration, but fill zero
beta_scalar->fill(zero<ValueType>());
num_generated_++;
num_generated_scalar_++;
}
// x = x + alpha * inner_solution
dense_x->add_scaled(alpha_scalar.get(), inner_solution);
Expand All @@ -286,12 +346,13 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
}
alpha_ref = ValueType{1.0} / (center_ - beta_ref / alpha_ref);
// The last one is always the updated one
if (num_generated_ < num_keep || iter >= num_keep) {
if (num_generated_scalar_ < num_max_generation_ ||
iter >= num_max_generation_) {
alpha_scalar->fill(alpha_ref);
beta_scalar->fill(beta_ref);
}
if (num_generated_ < num_keep) {
num_generated_++;
if (num_generated_scalar_ < num_max_generation_) {
num_generated_scalar_++;
}
// z = z + beta * p
// p = z
Expand Down
52 changes: 14 additions & 38 deletions include/ginkgo/core/solver/chebyshev.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ namespace solver {


/**
* Chebyshev iteration is an iterative method that uses another coarse
* method to approximate the error of the current solution via the current
* Chebyshev iteration is an iterative method that uses another inner
* solver to approximate the error of the current solution via the current
* residual. It has another term for the difference of solution. Moreover, this
* method requires knowledge about the spectrum of the matrix.
* method requires knowledge about the spectrum of the matrix. This
* implementation follows the algorithm in "Templates for the Solution of Linear
* Systems: Building Blocks for Iterative Methods, 2nd Edition".
*
* ```
* solution = initial_guess
Expand Down Expand Up @@ -156,7 +158,8 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
GKO_FACTORY_PARAMETER_VECTOR(criteria, nullptr);

/**
* Inner solver factory.
* Inner solver factory. If not provided this will result in a
* non-preconditioned Chebyshev iteration.
*/
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
solver, nullptr);
Expand All @@ -181,11 +184,6 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
*/
initial_guess_mode GKO_FACTORY_PARAMETER_SCALAR(
default_initial_guess, initial_guess_mode::provided);

/**
* The number of scalar to keep
*/
int GKO_FACTORY_PARAMETER_SCALAR(num_keep, 2);
};
GKO_ENABLE_LIN_OP_FACTORY(Chebyshev, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down Expand Up @@ -215,38 +213,16 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
{}

explicit Chebyshev(const Factory* factory,
std::shared_ptr<const LinOp> system_matrix)
: EnableLinOp<Chebyshev>(factory->get_executor(),
gko::transpose(system_matrix->get_size())),
EnableSolverBase<Chebyshev>{std::move(system_matrix)},
EnableIterativeBase<Chebyshev>{
stop::combine(factory->get_parameters().criteria)},
parameters_{factory->get_parameters()}
{
if (parameters_.generated_solver) {
this->set_solver(parameters_.generated_solver);
} else if (parameters_.solver) {
this->set_solver(
parameters_.solver->generate(this->get_system_matrix()));
} else {
this->set_solver(matrix::Identity<ValueType>::create(
this->get_executor(), this->get_size()));
}
this->set_default_initial_guess(parameters_.default_initial_guess);
center_ =
(std::get<0>(parameters_.foci) + std::get<1>(parameters_.foci)) /
ValueType{2};
// the absolute value of foci_direction is the focal direction
foci_direction_ =
(std::get<1>(parameters_.foci) - std::get<0>(parameters_.foci)) /
ValueType{2};
// if changing the lower/upper eig, need to reset it to zero
num_generated_ = 0;
}
std::shared_ptr<const LinOp> system_matrix);

private:
std::shared_ptr<const LinOp> solver_{};
mutable int num_generated_;
// num_generated_scalar_ is to track the number of generated scalar alpha
// and beta.
mutable size_type num_generated_scalar_;
// num_max_generation_ is the number of keeping the generated scalar in
// workspace.
mutable size_type num_max_generation_;
ValueType center_;
ValueType foci_direction_;
};
Expand Down
Loading

0 comments on commit 7df1d42

Please sign in to comment.