-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
286 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
// | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
|
||
#include <ginkgo/core/log/solver_debug.hpp> | ||
|
||
|
||
#include <iomanip> | ||
|
||
|
||
#include <ginkgo/core/base/exception_helpers.hpp> | ||
#include <ginkgo/core/base/name_demangling.hpp> | ||
#include <ginkgo/core/log/logger.hpp> | ||
#include <ginkgo/core/solver/solver_base.hpp> | ||
|
||
|
||
namespace gko { | ||
namespace log { | ||
|
||
|
||
static void print_scalar(const LinOp* value, std::ostream& stream) | ||
{ | ||
using conv_to_double = ConvertibleTo<matrix::Dense<double>>; | ||
using conv_to_complex = ConvertibleTo<matrix::Dense<std::complex<double>>>; | ||
const auto host_exec = value->get_executor()->get_master(); | ||
if (value->get_size()[0] == 0) { | ||
stream << "<empty>"; | ||
} else if (value->get_size()[0] != 1) { | ||
stream << "<matrix>"; | ||
} else if (dynamic_cast<const conv_to_double*>(value)) { | ||
auto host_value = matrix::Dense<double>::create(host_exec); | ||
host_value->copy_from(value); | ||
stream << host_value->at(0, 0); | ||
} else if (dynamic_cast<const conv_to_complex*>(value)) { | ||
auto host_value = | ||
matrix::Dense<std::complex<double>>::create(host_exec); | ||
host_value->copy_from(value); | ||
stream << host_value->at(0, 0); | ||
} else { | ||
stream << "<unknown>"; | ||
} | ||
} | ||
|
||
|
||
void SolverDebug::on_linop_apply_started(const LinOp* solver, const LinOp* in, | ||
const LinOp* out) const | ||
{ | ||
using solver_base = solver::detail::SolverBaseLinOp; | ||
auto dynamic_type = name_demangling::get_dynamic_type(*solver); | ||
auto& stream = *output_; | ||
stream << dynamic_type << "::apply(" << in << ',' << out | ||
<< ") of dimensions " << solver->get_size() << " and " | ||
<< in->get_size()[1] << " rhs\n"; | ||
if (const auto base = dynamic_cast<const solver_base*>(solver)) { | ||
const auto scalars = base->get_workspace_scalars(); | ||
const auto names = base->get_workspace_op_names(); | ||
stream << std::setw(column_width_) << "Iteration"; | ||
for (auto scalar : scalars) { | ||
stream << std::setw(column_width_) << names[scalar]; | ||
} | ||
stream << '\n'; | ||
} else { | ||
stream << "This solver type is not supported by the SolverDebug logger"; | ||
} | ||
} | ||
|
||
|
||
void SolverDebug::on_iteration_complete( | ||
const LinOp* solver, const LinOp* right_hand_side, const LinOp* solution, | ||
const size_type& num_iterations, const LinOp* residual, | ||
const LinOp* residual_norm, const LinOp* implicit_sq_residual_norm, | ||
const array<stopping_status>* status, bool stopped) const | ||
{ | ||
using solver_base = solver::detail::SolverBaseLinOp; | ||
auto& stream = *output_; | ||
stream << std::setprecision(precision_); | ||
if (const auto base = dynamic_cast<const solver_base*>(solver)) { | ||
const auto scalars = base->get_workspace_scalars(); | ||
stream << std::setw(column_width_) << num_iterations; | ||
for (auto scalar : scalars) { | ||
stream << std::setw(column_width_); | ||
print_scalar(base->get_workspace_op(scalar), stream); | ||
} | ||
stream << '\n'; | ||
} | ||
} | ||
|
||
|
||
void SolverDebug::on_iteration_complete(const LinOp* solver, | ||
const size_type& num_iterations, | ||
const LinOp* residual, | ||
const LinOp* solution, | ||
const LinOp* residual_norm) const | ||
{ | ||
on_iteration_complete(solver, nullptr, solution, num_iterations, residual, | ||
residual_norm, nullptr, nullptr, false); | ||
} | ||
|
||
|
||
void SolverDebug::on_iteration_complete( | ||
const LinOp* solver, const size_type& num_iterations, const LinOp* residual, | ||
const LinOp* solution, const LinOp* residual_norm, | ||
const LinOp* implicit_sq_residual_norm) const | ||
{ | ||
on_iteration_complete(solver, nullptr, solution, num_iterations, residual, | ||
residual_norm, implicit_sq_residual_norm, nullptr, | ||
false); | ||
} | ||
|
||
|
||
SolverDebug::SolverDebug(std::ostream& stream, int precision, int column_width) | ||
: output_{&stream}, precision_{precision}, column_width_{column_width} | ||
{} | ||
|
||
|
||
std::shared_ptr<SolverDebug> SolverDebug::create(std::ostream& output, | ||
int precision, | ||
int column_width) | ||
{ | ||
return std::shared_ptr<SolverDebug>{ | ||
new SolverDebug{output, precision, column_width}}; | ||
} | ||
|
||
|
||
} // namespace log | ||
} // namespace gko |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
// | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
|
||
#include <ginkgo/core/log/solver_debug.hpp> | ||
|
||
|
||
#include <gtest/gtest.h> | ||
|
||
|
||
#include <ginkgo/core/base/executor.hpp> | ||
#include <ginkgo/core/solver/cg.hpp> | ||
#include <ginkgo/core/stop/iteration.hpp> | ||
|
||
|
||
#include "core/test/utils.hpp" | ||
|
||
|
||
template <typename T> | ||
class SolverDebug : public ::testing::Test { | ||
public: | ||
using Dense = gko::matrix::Dense<T>; | ||
using Cg = gko::solver::Cg<T>; | ||
|
||
SolverDebug() : ref{gko::ReferenceExecutor::create()} | ||
{ | ||
mtx = gko::initialize<Dense>({T{1.0}}, ref); | ||
in = gko::initialize<Dense>({T{2.0}}, ref); | ||
out = mtx->clone(); | ||
solver = | ||
Cg::build() | ||
.with_criteria(gko::stop::Iteration::build().with_max_iters(1u)) | ||
.on(ref) | ||
->generate(mtx); | ||
} | ||
|
||
std::shared_ptr<gko::ReferenceExecutor> ref; | ||
std::shared_ptr<Dense> mtx; | ||
std::shared_ptr<Dense> in; | ||
std::unique_ptr<Dense> out; | ||
std::unique_ptr<Cg> solver; | ||
}; | ||
|
||
TYPED_TEST_SUITE(SolverDebug, gko::test::ValueTypes, TypenameNameGenerator); | ||
|
||
|
||
TYPED_TEST(SolverDebug, Works) | ||
{ | ||
using T = TypeParam; | ||
std::stringstream ref_ss; | ||
int default_column_width = 12; | ||
auto dynamic_type = gko::name_demangling::get_dynamic_type(*this->solver); | ||
ref_ss << dynamic_type << "::apply(" << this->in.get() << ',' | ||
<< this->out.get() << ") of dimensions " << this->solver->get_size() | ||
<< " and " << this->in->get_size()[1] << " rhs\n"; | ||
ref_ss << std::setw(default_column_width) << "Iteration" | ||
<< std::setw(default_column_width) << "alpha" | ||
<< std::setw(default_column_width) << "beta" | ||
<< std::setw(default_column_width) << "prev_rho" | ||
<< std::setw(default_column_width) << "rho" << '\n'; | ||
ref_ss << std::setw(default_column_width) << 0 | ||
<< std::setw(default_column_width) << T{0.0} | ||
<< std::setw(default_column_width) << T{0.0} | ||
<< std::setw(default_column_width) << T{1.0} | ||
<< std::setw(default_column_width) << T{1.0} << '\n' | ||
<< std::setw(default_column_width) << 1 | ||
<< std::setw(default_column_width) << T{0.0} | ||
<< std::setw(default_column_width) << T{1.0} | ||
<< std::setw(default_column_width) << T{0.0} | ||
<< std::setw(default_column_width) << T{1.0} << '\n'; | ||
std::stringstream ss; | ||
this->solver->add_logger(gko::log::SolverDebug::create(ss)); | ||
|
||
this->solver->apply(this->in, this->out); | ||
|
||
ASSERT_EQ(ss.str(), ref_ss.str()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
// | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
|
||
#ifndef GKO_PUBLIC_CORE_LOG_SOLVER_DEBUG_HPP_ | ||
#define GKO_PUBLIC_CORE_LOG_SOLVER_DEBUG_HPP_ | ||
|
||
|
||
#include <iosfwd> | ||
|
||
|
||
#include <ginkgo/config.hpp> | ||
#include <ginkgo/core/log/logger.hpp> | ||
|
||
|
||
namespace gko { | ||
namespace log { | ||
|
||
|
||
/** | ||
* This Logger prints the value of all scalar values stored internally by the | ||
* solver after each iteration. If the solver is applied to multiple right-hand | ||
* sides, only the first right-hand side gets printed. | ||
*/ | ||
class SolverDebug : public Logger { | ||
public: | ||
/* Internal solver events */ | ||
void on_linop_apply_started(const LinOp* A, const LinOp* b, | ||
const LinOp* x) const override; | ||
|
||
void on_iteration_complete( | ||
const LinOp* solver, const LinOp* right_hand_side, | ||
const LinOp* solution, const size_type& num_iterations, | ||
const LinOp* residual, const LinOp* residual_norm, | ||
const LinOp* implicit_sq_residual_norm, | ||
const array<stopping_status>* status, bool stopped) const override; | ||
|
||
GKO_DEPRECATED( | ||
"Please use the version with the additional stopping " | ||
"information.") | ||
void on_iteration_complete(const LinOp* solver, | ||
const size_type& num_iterations, | ||
const LinOp* residual, const LinOp* solution, | ||
const LinOp* residual_norm) const override; | ||
|
||
GKO_DEPRECATED( | ||
"Please use the version with the additional stopping " | ||
"information.") | ||
void on_iteration_complete( | ||
const LinOp* solver, const size_type& num_iterations, | ||
const LinOp* residual, const LinOp* solution, | ||
const LinOp* residual_norm, | ||
const LinOp* implicit_sq_residual_norm) const override; | ||
|
||
/** | ||
* Creates a logger printing the value for all scalar values in the solver | ||
* after each iteration. | ||
* | ||
* @param output the stream to write the output to. | ||
* @param precision the number of digits of precision to print | ||
* @param column_width the number of characters an output column is wide | ||
*/ | ||
static std::shared_ptr<SolverDebug> create(std::ostream& output, | ||
int precision = 6, | ||
int column_width = 12); | ||
|
||
private: | ||
SolverDebug(std::ostream& output, int precision, int column_width); | ||
|
||
std::ostream* output_; | ||
int precision_; | ||
int column_width_; | ||
}; | ||
|
||
|
||
} // namespace log | ||
} // namespace gko | ||
|
||
|
||
#endif // GKO_PUBLIC_CORE_LOG_SOLVER_DEBUG_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters