Skip to content

Commit

Permalink
add history to convergence logger
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Dec 21, 2023
1 parent efa9021 commit 6b42557
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 48 deletions.
59 changes: 44 additions & 15 deletions core/log/convergence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,34 +57,55 @@ void Convergence<ValueType>::on_iteration_complete(
const LinOp* residual_norm, const LinOp* implicit_resnorm_sq,
const array<stopping_status>* status, const bool stopped) const
{
auto update_history = [&](auto& container, auto new_val, bool is_norm) {
if (history_ == convergence_history::none) {
if (container.empty()) {
container.emplace_back(nullptr);
}
container.back() = std::move(new_val);
return;
}
if (is_norm || history_ == convergence_history::full) {
container.emplace_back(std::move(new_val));
}
};
if (num_iterations == 0) {
residual_.clear();
residual_norm_.clear();
implicit_sq_resnorm_.clear();
}
if (stopped) {
array<stopping_status> tmp(status->get_executor()->get_master(),
*status);
this->convergence_status_ = true;
convergence_status_ = true;
for (int i = 0; i < status->get_size(); i++) {
if (!tmp.get_data()[i].has_converged()) {
this->convergence_status_ = false;
convergence_status_ = false;
break;
}
}
this->num_iterations_ = num_iterations;
num_iterations_ = num_iterations;
}
if (stopped || history_ != convergence_history::none) {
if (residual != nullptr) {
this->residual_.reset(residual->clone().release());
update_history(residual_, residual->clone(), false);
}
if (implicit_resnorm_sq != nullptr) {
this->implicit_sq_resnorm_.reset(
implicit_resnorm_sq->clone().release());
update_history(implicit_sq_resnorm_, implicit_resnorm_sq->clone(),
true);
}
if (residual_norm != nullptr) {
this->residual_norm_.reset(residual_norm->clone().release());
update_history(residual_norm_, residual_norm->clone(), true);
} else if (residual != nullptr) {
using NormVector = matrix::Dense<remove_complex<ValueType>>;
detail::vector_dispatch<ValueType>(
residual, [&](const auto* dense_r) {
this->residual_norm_ =
update_history(
residual_norm_,
NormVector::create(residual->get_executor(),
dim<2>{1, residual->get_size()[1]});
dense_r->compute_norm2(this->residual_norm_);
dim<2>{1, residual->get_size()[1]}),
true);
dense_r->compute_norm2(residual_norm_.back());
});
} else if (dynamic_cast<const solver::detail::SolverBaseLinOp*>(
solver) &&
Expand All @@ -97,13 +118,21 @@ void Convergence<ValueType>::on_iteration_complete(
detail::vector_dispatch<ValueType>(b, [&](const auto* dense_b) {
detail::vector_dispatch<ValueType>(x, [&](const auto* dense_x) {
auto exec = system_mtx->get_executor();
auto residual = dense_b->clone();
this->residual_norm_ = NormVector::create(
exec, dim<2>{1, residual->get_size()[1]});
update_history(residual_, dense_b->clone(), false);
system_mtx->apply(initialize<Vector>({-1.0}, exec), dense_x,
initialize<Vector>({1.0}, exec),
residual);
residual->compute_norm2(this->residual_norm_);
residual_.back());
update_history(
residual_norm_,
NormVector::create(
exec, dim<2>{1, residual_.back()->get_size()[1]}),
true);
detail::vector_dispatch<ValueType>(
residual_.back().get(),
[&](const auto* actual_residual) {
actual_residual->compute_norm2(
residual_norm_.back());
});
});
});
}
Expand Down
92 changes: 92 additions & 0 deletions core/test/log/convergence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ TYPED_TEST(Convergence, CanGetEmptyData)
ASSERT_EQ(logger->get_residual(), nullptr);
ASSERT_EQ(logger->get_residual_norm(), nullptr);
ASSERT_EQ(logger->get_implicit_sq_resnorm(), nullptr);
ASSERT_TRUE(logger->get_residual_history().empty());
ASSERT_TRUE(logger->get_residual_norm_history().empty());
ASSERT_TRUE(logger->get_implicit_sq_resnorm_history().empty());
}


Expand Down Expand Up @@ -100,6 +103,10 @@ TYPED_TEST(Convergence, DoesNotLogIfNotStopped)
ASSERT_EQ(logger->get_num_iterations(), 0);
ASSERT_EQ(logger->get_residual(), nullptr);
ASSERT_EQ(logger->get_residual_norm(), nullptr);
ASSERT_EQ(logger->get_implicit_sq_resnorm(), nullptr);
ASSERT_TRUE(logger->get_residual_history().empty());
ASSERT_TRUE(logger->get_residual_norm_history().empty());
ASSERT_TRUE(logger->get_implicit_sq_resnorm_history().empty());
}


Expand Down Expand Up @@ -131,4 +138,89 @@ TYPED_TEST(Convergence, CanComputeResidualNormFromSolution)
}


TYPED_TEST(Convergence, CanLogDataWithNormHistory)
{
using AbsoluteDense = gko::matrix::Dense<gko::remove_complex<TypeParam>>;
auto logger = gko::log::Convergence<TypeParam>::create(
gko::convergence_history::norm);

logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 100, nullptr,
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 101, nullptr,
this->residual_norm.get(), this->implicit_sq_resnorm.get(),
&this->status, true);

ASSERT_EQ(logger->get_residual_history().size(), 0);
ASSERT_EQ(logger->get_residual_norm_history().size(), 2);
ASSERT_EQ(logger->get_implicit_sq_resnorm_history().size(), 2);
for (int i : {0, 1}) {
GKO_ASSERT_MTX_NEAR(
gko::as<AbsoluteDense>(logger->get_residual_norm_history()[i]),
this->residual_norm, 0);
GKO_ASSERT_MTX_NEAR(gko::as<AbsoluteDense>(
logger->get_implicit_sq_resnorm_history()[i]),
this->implicit_sq_resnorm, 0);
}
}


TYPED_TEST(Convergence, CanLogDataWithFullHistory)
{
using Dense = gko::matrix::Dense<TypeParam>;
using AbsoluteDense = gko::matrix::Dense<gko::remove_complex<TypeParam>>;
auto logger = gko::log::Convergence<TypeParam>::create(
gko::convergence_history::full);

logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 100, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 101, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(),
&this->status, true);

ASSERT_EQ(logger->get_residual_history().size(), 2);
ASSERT_EQ(logger->get_residual_norm_history().size(), 2);
ASSERT_EQ(logger->get_implicit_sq_resnorm_history().size(), 2);
for (int i : {0, 1}) {
GKO_ASSERT_MTX_NEAR(gko::as<Dense>(logger->get_residual_history()[i]),
this->residual, 0);
GKO_ASSERT_MTX_NEAR(
gko::as<AbsoluteDense>(logger->get_residual_norm_history()[i]),
this->residual_norm, 0);
GKO_ASSERT_MTX_NEAR(gko::as<AbsoluteDense>(
logger->get_implicit_sq_resnorm_history()[i]),
this->implicit_sq_resnorm, 0);
}
}


TYPED_TEST(Convergence, CanClearHistory)
{
auto logger = gko::log::Convergence<TypeParam>::create(
gko::convergence_history::full);

logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 100, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 101, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(),
&this->status, true);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 0, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);

ASSERT_EQ(logger->get_residual_history().size(), 1);
ASSERT_EQ(logger->get_residual_norm_history().size(), 1);
ASSERT_EQ(logger->get_implicit_sq_resnorm_history().size(), 1);
}


} // namespace
Loading

0 comments on commit 6b42557

Please sign in to comment.