Skip to content

Commit

Permalink
Fix apply and more ref tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Oct 22, 2023
1 parent cc0d037 commit f752d83
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 60 deletions.
96 changes: 49 additions & 47 deletions core/test/utils/batch_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ template <typename MatrixType>
std::unique_ptr<
batch::MultiVector<remove_complex<typename MatrixType::value_type>>>
compute_residual_norms(
const MatrixType* const mtx,
const batch::MultiVector<typename MatrixType::value_type>* const b,
const batch::MultiVector<typename MatrixType::value_type>* const x)
const MatrixType* mtx,
const batch::MultiVector<typename MatrixType::value_type>* b,
const batch::MultiVector<typename MatrixType::value_type>* x)
{
using value_type = typename MatrixType::value_type;
using multi_vec = batch::MultiVector<value_type>;
Expand Down Expand Up @@ -249,7 +249,7 @@ Result<typename MatrixType::value_type> solve_linear_system(

solver->apply(sys.rhs, result.x);
result.res_norm =
compute_residual_norms(sys.matrix.get(), result.x.get(), sys.rhs.get());
compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get());

return std::move(result);
}
Expand Down Expand Up @@ -277,10 +277,11 @@ Result<typename MatrixType::value_type> solve_linear_system(
// Initialize r to the original unscaled b
result.x = sys.rhs->clone();

result.logdata.res_norms =
gko::batch::log::BatchLogData<double> logdata;
logdata.res_norms =
gko::batch::MultiVector<double>::create(exec, norm_size);
result.logdata.iter_counts.set_executor(exec);
result.logdata.iter_counts.resize_and_reset(num_rhs * num_batch_items);
logdata.iter_counts.set_executor(exec);
logdata.iter_counts.resize_and_reset(num_rhs * num_batch_items);

std::unique_ptr<gko::batch::BatchLinOp> precond;
if (precond_factory) {
Expand All @@ -290,50 +291,51 @@ Result<typename MatrixType::value_type> solve_linear_system(
}

solve_function(settings, precond.get(), sys.matrix.get(), sys.rhs.get(),
result.x.get(), result.logdata);
result.x.get(), logdata);

result.logdata.res_norms =
gko::batch::MultiVector<double>::create(exec->get_master(), norm_size);
result.logdata.iter_counts.set_executor(exec->get_master());
result.logdata.iter_counts.resize_and_reset(num_rhs * num_batch_items);
result.logdata.res_norms->copy_from(logdata.res_norms.get());
result.logdata.iter_counts = logdata.iter_counts;

result.res_norm =
compute_residual_norms(sys.matrix.get(), result.x.get(), sys.rhs.get());
compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get());

return std::move(result);
}


template <typename ValueType>
struct BatchSystem {
using vec_type = batch::MultiVector<ValueType>;
std::unique_ptr<batch::BatchLinOp> A;
std::unique_ptr<vec_type> b;
};


template <typename MatrixType, typename... MatrixArgs>
BatchSystem<typename MatrixType::value_type>
generate_diag_dominant_batch_system(std::shared_ptr<const gko::Executor> exec,
const size_type num_batch_items,
const int num_rows, const int num_rhs,
const bool is_hermitian,
MatrixArgs&&... args)
LinearSystem<MatrixType> generate_diag_dominant_batch_problem(
std::shared_ptr<const gko::Executor> exec, const size_type num_batch_items,
const int num_rows, const int num_rhs, const bool is_hermitian,
MatrixArgs&&... args)
{
using value_type = typename MatrixType::value_type;
using index_type = typename MatrixType::index_type;
using unbatch_type = typename MatrixType::unbatch_type;
using real_type = remove_complex<value_type>;
using unbatch_type = typename MatrixType::unbatch_type;
using multi_vec = batch::MultiVector<value_type>;
using real_vec = batch::MultiVector<real_type>;
const int num_cols = num_rows;
gko::matrix_data<value_type, index_type> data{
gko::dim<2>{static_cast<size_type>(num_rows),
static_cast<size_type>(num_cols)},
{}};
auto engine = std::default_random_engine(42);
auto rand_diag_dist = std::normal_distribution<value_type>(4.0, 12.0);
auto rand_diag_dist = std::normal_distribution<real_type>(4.0, 12.0);
for (int row = 1; row < num_rows - 1; ++row) {
auto rand_nnz_dist = std::normal_distribution<index_type>(1, row + 1);
auto k = detail::get_rand_value<index_type>(rand_nnz_dist, engine);
std::uniform_int_distribution<index_type> rand_nnz_dist{1, row + 1};
const auto k = rand_nnz_dist(engine);
data.nonzeros.emplace_back(row, k, value_type{-1.0});
data.nonzeros.emplace_back(row, row + 1, value_type{-1.0});
data.nonzeros.emplace_back(row - 1, row, value_type{-1.0});
data.nonzeros.emplace_back(
row, row, detail::get_rand_value(rand_diag_dist, engine));
row, row,
static_cast<value_type>(
detail::get_rand_value<real_type>(rand_diag_dist, engine)));
}
data.nonzeros.emplace_back(0, 0, value_type{2.0});
data.nonzeros.emplace_back(num_rows - 1, num_rows - 1, value_type{2.0});
Expand All @@ -356,41 +358,41 @@ generate_diag_dominant_batch_system(std::shared_ptr<const gko::Executor> exec,
exec->get_master(), soa_data.get_num_elems(),
soa_data.get_const_col_idxs())
.copy_to_array();
auto result = MatrixType::create(
exec, batch_dim<2>(num_batch_items, dim<2>(num_rows, num_cols)),
std::forward<MatrixArgs>(args)...);
auto rand_val_dist = std::normal_distribution<value_type>(-0.5, 0.5);

std::vector<gko::matrix_data<value_type, index_type>> batch_data(
num_batch_items);
batch_data.reserve(num_batch_items);
BatchSystem<value_type> sys;

auto rand_val_dist = std::normal_distribution<>(-0.5, 0.5);
for (size_type b = 1; b < num_batch_items; b++) {
auto rand_data = fill_random_matrix_data<value_type, index_type>(
num_rows, num_cols, row_idxs, col_idxs, rand_val_dist, engine);
if (is_hermitian) {
gko::utils::make_hpd(rand_data);
} else {
gko::utils::make_diag_dominant(rand_data);
}
gko::utils::make_diag_dominant(rand_data);
batch_data.emplace_back(rand_data);
}
sys.A = gko::give(gko::batch::read<value_type, index_type, MatrixType>(

LinearSystem<MatrixType> sys;
sys.matrix = gko::give(gko::batch::read<value_type, index_type, MatrixType>(
exec, batch_data, std::forward<MatrixArgs>(args)...));

std::vector<gko::matrix_data<value_type, index_type>> batch_rhs_data(
std::vector<gko::matrix_data<value_type, index_type>> batch_sol_data(
num_batch_items);
batch_rhs_data.reserve(num_batch_items);
batch_sol_data.reserve(num_batch_items);
for (size_type b = 0; b < num_batch_items; b++) {
auto rand_data = generate_random_matrix_data<value_type, index_type>(
num_rows, num_cols,
std::normal_distribution<index_type>(num_rhs, num_rhs),
std::uniform_int_distribution<index_type>(num_rhs, num_rhs),
rand_val_dist, engine);
batch_data.emplace_back(rand_data);
batch_sol_data.emplace_back(rand_data);
}
sys.b = gko::give(gko::batch::read<value_type, index_type,
BatchSystem<value_type>::vec_type>(
exec, batch_rhs_data));
sys.exact_sol = gko::give(
gko::batch::read<value_type, index_type,
typename LinearSystem<MatrixType>::multi_vec>(
exec, batch_sol_data));
sys.rhs = sys.exact_sol->clone();
sys.matrix->apply(sys.exact_sol, sys.rhs);
const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs));
sys.rhs_norm = real_vec::create(exec, norm_dim);
sys.rhs->compute_norm2(sys.rhs_norm.get());
return sys;
}

Expand Down
35 changes: 27 additions & 8 deletions include/ginkgo/core/solver/batch_solver_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,21 @@ class EnableBatchSolver
: public BatchSolver,
public EnableBatchLinOp<ConcreteSolver, PolymorphicBase> {
public:
ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
ptr_param<MultiVector<ValueType>> x) const
const ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
ptr_param<MultiVector<ValueType>> x) const
{
this->validate_application_parameters(b.get(), x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, x).get());
return this;
return self();
}

ConcreteSolver* apply_impl(ptr_param<const MultiVector<ValueType>>* alpha,
ptr_param<const MultiVector<ValueType>>* b,
ptr_param<const MultiVector<ValueType>>* beta,
ptr_param<MultiVector<ValueType>>* x) const
const ConcreteSolver* apply_impl(
ptr_param<const MultiVector<ValueType>>* alpha,
ptr_param<const MultiVector<ValueType>>* b,
ptr_param<const MultiVector<ValueType>>* beta,
ptr_param<MultiVector<ValueType>>* x) const
{
this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
x.get());
Expand All @@ -178,10 +179,28 @@ class EnableBatchSolver
make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, beta).get(),
make_temporary_clone(exec, x).get());
return this;
return self();
}

ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
ptr_param<MultiVector<ValueType>> x)
{
static_cast<const ConcreteSolver*>(this)->apply(b, x);
return self();
}

ConcreteSolver* apply_impl(ptr_param<const MultiVector<ValueType>>* alpha,
ptr_param<const MultiVector<ValueType>>* b,
ptr_param<const MultiVector<ValueType>>* beta,
ptr_param<MultiVector<ValueType>>* x)
{
static_cast<const ConcreteSolver*>(this)->apply(alpha, b, beta, x);
return self();
}

protected:
GKO_ENABLE_SELF(ConcreteSolver);

explicit EnableBatchSolver(std::shared_ptr<const Executor> exec)
: EnableBatchLinOp<ConcreteSolver, PolymorphicBase>(std::move(exec))
{}
Expand Down
95 changes: 90 additions & 5 deletions reference/test/solver/batch_bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/log/batch_logger.hpp>
#include <ginkgo/core/matrix/batch_dense.hpp>
#include <ginkgo/core/matrix/batch_ell.hpp>


#include "core/base/batch_utilities.hpp"
Expand All @@ -58,6 +60,7 @@ class BatchBicgstab : public ::testing::Test {
using real_type = gko::remove_complex<value_type>;
using solver_type = gko::batch::solver::Bicgstab<value_type>;
using Mtx = gko::batch::matrix::Dense<value_type>;
using EllMtx = gko::batch::matrix::Ell<value_type>;
using MVec = gko::batch::MultiVector<value_type>;
using RealMVec = gko::batch::MultiVector<real_type>;
using Settings = gko::kernels::batch_bicgstab::BicgstabOptions<real_type>;
Expand Down Expand Up @@ -111,7 +114,7 @@ TYPED_TEST(BatchBicgstab, SolvesStencilSystem)
}


TYPED_TEST(BatchBicgstab, StencilSystemLoggerIsCorrect)
TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsResidual)
{
using value_type = typename TestFixture::value_type;
using real_type = gko::remove_complex<value_type>;
Expand All @@ -125,9 +128,6 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerIsCorrect)
const double* const res_log_array =
res.logdata.res_norms->get_const_values();
for (size_t i = 0; i < this->num_batch_items; i++) {
// test logger
GKO_ASSERT((iter_array[i] <= ref_iters + 1) &&
(iter_array[i] >= ref_iters - 1));
ASSERT_LE(res_log_array[i] / this->linear_system.rhs_norm->at(i, 0, 0),
this->solver_settings.residual_tol);
ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i],
Expand All @@ -136,6 +136,25 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerIsCorrect)
}


TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsIterations)
{
using value_type = typename TestFixture::value_type;
using Settings = typename TestFixture::Settings;
using real_type = gko::remove_complex<value_type>;
const int ref_iters = 5;
const Settings solver_settings{ref_iters, 0,
gko::batch::stop::ToleranceType::relative};

auto res = gko::test::solve_linear_system(
this->exec, this->solve_lambda, solver_settings, this->linear_system);

const int* const iter_array = res.logdata.iter_counts.get_const_data();
for (size_t i = 0; i < this->num_batch_items; i++) {
ASSERT_EQ(iter_array[i], ref_iters);
}
}


TYPED_TEST(BatchBicgstab, CanSolveDenseSystem)
{
using value_type = typename TestFixture::value_type;
Expand All @@ -160,7 +179,73 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem)
auto res =
gko::test::solve_linear_system(this->exec, linear_system, solver);

GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, this->eps);
GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10);
for (size_t i = 0; i < num_batch_items; i++) {
ASSERT_LE(res.res_norm->get_const_values()[i] /
linear_system.rhs_norm->get_const_values()[i],
tol);
}
}


TYPED_TEST(BatchBicgstab, CanSolveEllSystem)
{
using value_type = typename TestFixture::value_type;
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::EllMtx;
const real_type tol = 1e-5;
const int max_iters = 1000;
auto solver_factory =
Solver::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);
const int num_rows = 13;
const size_t num_batch_items = 5;
const int num_rhs = 1;
auto linear_system = gko::test::generate_3pt_stencil_batch_problem<Mtx>(
this->exec, num_batch_items, num_rows, num_rhs, 3);
auto solver = gko::share(solver_factory->generate(linear_system.matrix));

auto res =
gko::test::solve_linear_system(this->exec, linear_system, solver);

GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10);
for (size_t i = 0; i < num_batch_items; i++) {
ASSERT_LE(res.res_norm->get_const_values()[i] /
linear_system.rhs_norm->get_const_values()[i],
tol);
}
}


TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem)
{
using value_type = typename TestFixture::value_type;
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::Mtx;
const real_type tol = 1e-5;
const int max_iters = 1000;
auto solver_factory =
Solver::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);
const int num_rows = 65;
const gko::size_type num_batch_items = 5;
const int num_rhs = 1;
auto linear_system = gko::test::generate_diag_dominant_batch_problem<Mtx>(
this->exec, num_batch_items, num_rows, num_rhs, true);
auto solver = gko::share(solver_factory->generate(linear_system.matrix));

auto res =
gko::test::solve_linear_system(this->exec, linear_system, solver);

GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10);
for (size_t i = 0; i < num_batch_items; i++) {
ASSERT_LE(res.res_norm->get_const_values()[i] /
linear_system.rhs_norm->get_const_values()[i],
Expand Down

0 comments on commit f752d83

Please sign in to comment.