diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp index 4f83b4bc2e5..4e379009a83 100644 --- a/core/test/utils/batch_helpers.hpp +++ b/core/test/utils/batch_helpers.hpp @@ -191,9 +191,9 @@ template std::unique_ptr< batch::MultiVector>> compute_residual_norms( - const MatrixType* const mtx, - const batch::MultiVector* const b, - const batch::MultiVector* const x) + const MatrixType* mtx, + const batch::MultiVector* b, + const batch::MultiVector* x) { using value_type = typename MatrixType::value_type; using multi_vec = batch::MultiVector; @@ -249,7 +249,7 @@ Result solve_linear_system( solver->apply(sys.rhs, result.x); result.res_norm = - compute_residual_norms(sys.matrix.get(), result.x.get(), sys.rhs.get()); + compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get()); return std::move(result); } @@ -277,10 +277,11 @@ Result solve_linear_system( // Initialize r to the original unscaled b result.x = sys.rhs->clone(); - result.logdata.res_norms = + gko::batch::log::BatchLogData logdata; + logdata.res_norms = gko::batch::MultiVector::create(exec, norm_size); - result.logdata.iter_counts.set_executor(exec); - result.logdata.iter_counts.resize_and_reset(num_rhs * num_batch_items); + logdata.iter_counts.set_executor(exec); + logdata.iter_counts.resize_and_reset(num_rhs * num_batch_items); std::unique_ptr precond; if (precond_factory) { @@ -290,50 +291,51 @@ Result solve_linear_system( } solve_function(settings, precond.get(), sys.matrix.get(), sys.rhs.get(), - result.x.get(), result.logdata); + result.x.get(), logdata); + + result.logdata.res_norms = + gko::batch::MultiVector::create(exec->get_master(), norm_size); + result.logdata.iter_counts.set_executor(exec->get_master()); + result.logdata.iter_counts.resize_and_reset(num_rhs * num_batch_items); + result.logdata.res_norms->copy_from(logdata.res_norms.get()); + result.logdata.iter_counts = logdata.iter_counts; result.res_norm = - compute_residual_norms(sys.matrix.get(), result.x.get(), sys.rhs.get()); + compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get()); return std::move(result); } -template -struct BatchSystem { - using vec_type = batch::MultiVector; - std::unique_ptr A; - std::unique_ptr b; -}; - - template -BatchSystem -generate_diag_dominant_batch_system(std::shared_ptr exec, - const size_type num_batch_items, - const int num_rows, const int num_rhs, - const bool is_hermitian, - MatrixArgs&&... args) +LinearSystem generate_diag_dominant_batch_problem( + std::shared_ptr exec, const size_type num_batch_items, + const int num_rows, const int num_rhs, const bool is_hermitian, + MatrixArgs&&... args) { using value_type = typename MatrixType::value_type; using index_type = typename MatrixType::index_type; - using unbatch_type = typename MatrixType::unbatch_type; using real_type = remove_complex; + using unbatch_type = typename MatrixType::unbatch_type; + using multi_vec = batch::MultiVector; + using real_vec = batch::MultiVector; const int num_cols = num_rows; gko::matrix_data data{ gko::dim<2>{static_cast(num_rows), static_cast(num_cols)}, {}}; auto engine = std::default_random_engine(42); - auto rand_diag_dist = std::normal_distribution(4.0, 12.0); + auto rand_diag_dist = std::normal_distribution(4.0, 12.0); for (int row = 1; row < num_rows - 1; ++row) { - auto rand_nnz_dist = std::normal_distribution(1, row + 1); - auto k = detail::get_rand_value(rand_nnz_dist, engine); + std::uniform_int_distribution rand_nnz_dist{1, row + 1}; + const auto k = rand_nnz_dist(engine); data.nonzeros.emplace_back(row, k, value_type{-1.0}); data.nonzeros.emplace_back(row, row + 1, value_type{-1.0}); data.nonzeros.emplace_back(row - 1, row, value_type{-1.0}); data.nonzeros.emplace_back( - row, row, detail::get_rand_value(rand_diag_dist, engine)); + row, row, + static_cast( + detail::get_rand_value(rand_diag_dist, engine))); } data.nonzeros.emplace_back(0, 0, value_type{2.0}); data.nonzeros.emplace_back(num_rows - 1, num_rows - 1, value_type{2.0}); @@ -356,41 +358,41 @@ generate_diag_dominant_batch_system(std::shared_ptr exec, exec->get_master(), soa_data.get_num_elems(), soa_data.get_const_col_idxs()) .copy_to_array(); - auto result = MatrixType::create( - exec, batch_dim<2>(num_batch_items, dim<2>(num_rows, num_cols)), - std::forward(args)...); - auto rand_val_dist = std::normal_distribution(-0.5, 0.5); + std::vector> batch_data( num_batch_items); batch_data.reserve(num_batch_items); - BatchSystem sys; - + auto rand_val_dist = std::normal_distribution<>(-0.5, 0.5); for (size_type b = 1; b < num_batch_items; b++) { auto rand_data = fill_random_matrix_data( num_rows, num_cols, row_idxs, col_idxs, rand_val_dist, engine); - if (is_hermitian) { - gko::utils::make_hpd(rand_data); - } else { - gko::utils::make_diag_dominant(rand_data); - } + gko::utils::make_diag_dominant(rand_data); batch_data.emplace_back(rand_data); } - sys.A = gko::give(gko::batch::read( + + LinearSystem sys; + sys.matrix = gko::give(gko::batch::read( exec, batch_data, std::forward(args)...)); - std::vector> batch_rhs_data( + std::vector> batch_sol_data( num_batch_items); - batch_rhs_data.reserve(num_batch_items); + batch_sol_data.reserve(num_batch_items); for (size_type b = 0; b < num_batch_items; b++) { auto rand_data = generate_random_matrix_data( num_rows, num_cols, - std::normal_distribution(num_rhs, num_rhs), + std::uniform_int_distribution(num_rhs, num_rhs), rand_val_dist, engine); - batch_data.emplace_back(rand_data); + batch_sol_data.emplace_back(rand_data); } - sys.b = gko::give(gko::batch::read::vec_type>( - exec, batch_rhs_data)); + sys.exact_sol = gko::give( + gko::batch::read::multi_vec>( + exec, batch_sol_data)); + sys.rhs = sys.exact_sol->clone(); + sys.matrix->apply(sys.exact_sol, sys.rhs); + const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); + sys.rhs_norm = real_vec::create(exec, norm_dim); + sys.rhs->compute_norm2(sys.rhs_norm.get()); return sys; } diff --git a/include/ginkgo/core/solver/batch_solver_base.hpp b/include/ginkgo/core/solver/batch_solver_base.hpp index ac24d421b80..c22d19420f4 100644 --- a/include/ginkgo/core/solver/batch_solver_base.hpp +++ b/include/ginkgo/core/solver/batch_solver_base.hpp @@ -156,20 +156,21 @@ class EnableBatchSolver : public BatchSolver, public EnableBatchLinOp { public: - ConcreteSolver* apply(ptr_param> b, - ptr_param> x) const + const ConcreteSolver* apply(ptr_param> b, + ptr_param> x) const { this->validate_application_parameters(b.get(), x.get()); auto exec = this->get_executor(); this->apply_impl(make_temporary_clone(exec, b).get(), make_temporary_clone(exec, x).get()); - return this; + return self(); } - ConcreteSolver* apply_impl(ptr_param>* alpha, - ptr_param>* b, - ptr_param>* beta, - ptr_param>* x) const + const ConcreteSolver* apply_impl( + ptr_param>* alpha, + ptr_param>* b, + ptr_param>* beta, + ptr_param>* x) const { this->validate_application_parameters(alpha.get(), b.get(), beta.get(), x.get()); @@ -178,10 +179,28 @@ class EnableBatchSolver make_temporary_clone(exec, b).get(), make_temporary_clone(exec, beta).get(), make_temporary_clone(exec, x).get()); - return this; + return self(); + } + + ConcreteSolver* apply(ptr_param> b, + ptr_param> x) + { + static_cast(this)->apply(b, x); + return self(); + } + + ConcreteSolver* apply_impl(ptr_param>* alpha, + ptr_param>* b, + ptr_param>* beta, + ptr_param>* x) + { + static_cast(this)->apply(alpha, b, beta, x); + return self(); } protected: + GKO_ENABLE_SELF(ConcreteSolver); + explicit EnableBatchSolver(std::shared_ptr exec) : EnableBatchLinOp(std::move(exec)) {} diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp index 477a9e3f686..0445c8c09cf 100644 --- a/reference/test/solver/batch_bicgstab_kernels.cpp +++ b/reference/test/solver/batch_bicgstab_kernels.cpp @@ -42,6 +42,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include +#include +#include #include "core/base/batch_utilities.hpp" @@ -58,6 +60,7 @@ class BatchBicgstab : public ::testing::Test { using real_type = gko::remove_complex; using solver_type = gko::batch::solver::Bicgstab; using Mtx = gko::batch::matrix::Dense; + using EllMtx = gko::batch::matrix::Ell; using MVec = gko::batch::MultiVector; using RealMVec = gko::batch::MultiVector; using Settings = gko::kernels::batch_bicgstab::BicgstabOptions; @@ -111,7 +114,7 @@ TYPED_TEST(BatchBicgstab, SolvesStencilSystem) } -TYPED_TEST(BatchBicgstab, StencilSystemLoggerIsCorrect) +TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsResidual) { using value_type = typename TestFixture::value_type; using real_type = gko::remove_complex; @@ -125,9 +128,6 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerIsCorrect) const double* const res_log_array = res.logdata.res_norms->get_const_values(); for (size_t i = 0; i < this->num_batch_items; i++) { - // test logger - GKO_ASSERT((iter_array[i] <= ref_iters + 1) && - (iter_array[i] >= ref_iters - 1)); ASSERT_LE(res_log_array[i] / this->linear_system.rhs_norm->at(i, 0, 0), this->solver_settings.residual_tol); ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i], @@ -136,6 +136,25 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerIsCorrect) } +TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsIterations) +{ + using value_type = typename TestFixture::value_type; + using Settings = typename TestFixture::Settings; + using real_type = gko::remove_complex; + const int ref_iters = 5; + const Settings solver_settings{ref_iters, 0, + gko::batch::stop::ToleranceType::relative}; + + auto res = gko::test::solve_linear_system( + this->exec, this->solve_lambda, solver_settings, this->linear_system); + + const int* const iter_array = res.logdata.iter_counts.get_const_data(); + for (size_t i = 0; i < this->num_batch_items; i++) { + ASSERT_EQ(iter_array[i], ref_iters); + } +} + + TYPED_TEST(BatchBicgstab, CanSolveDenseSystem) { using value_type = typename TestFixture::value_type; @@ -160,7 +179,73 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem) auto res = gko::test::solve_linear_system(this->exec, linear_system, solver); - GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, this->eps); + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); + for (size_t i = 0; i < num_batch_items; i++) { + ASSERT_LE(res.res_norm->get_const_values()[i] / + linear_system.rhs_norm->get_const_values()[i], + tol); + } +} + + +TYPED_TEST(BatchBicgstab, CanSolveEllSystem) +{ + using value_type = typename TestFixture::value_type; + using real_type = gko::remove_complex; + using Solver = typename TestFixture::solver_type; + using Mtx = typename TestFixture::EllMtx; + const real_type tol = 1e-5; + const int max_iters = 1000; + auto solver_factory = + Solver::build() + .with_default_max_iterations(max_iters) + .with_default_residual_tol(tol) + .with_tolerance_type(gko::batch::stop::ToleranceType::relative) + .on(this->exec); + const int num_rows = 13; + const size_t num_batch_items = 5; + const int num_rhs = 1; + auto linear_system = gko::test::generate_3pt_stencil_batch_problem( + this->exec, num_batch_items, num_rows, num_rhs, 3); + auto solver = gko::share(solver_factory->generate(linear_system.matrix)); + + auto res = + gko::test::solve_linear_system(this->exec, linear_system, solver); + + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); + for (size_t i = 0; i < num_batch_items; i++) { + ASSERT_LE(res.res_norm->get_const_values()[i] / + linear_system.rhs_norm->get_const_values()[i], + tol); + } +} + + +TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem) +{ + using value_type = typename TestFixture::value_type; + using real_type = gko::remove_complex; + using Solver = typename TestFixture::solver_type; + using Mtx = typename TestFixture::Mtx; + const real_type tol = 1e-5; + const int max_iters = 1000; + auto solver_factory = + Solver::build() + .with_default_max_iterations(max_iters) + .with_default_residual_tol(tol) + .with_tolerance_type(gko::batch::stop::ToleranceType::relative) + .on(this->exec); + const int num_rows = 65; + const gko::size_type num_batch_items = 5; + const int num_rhs = 1; + auto linear_system = gko::test::generate_diag_dominant_batch_problem( + this->exec, num_batch_items, num_rows, num_rhs, true); + auto solver = gko::share(solver_factory->generate(linear_system.matrix)); + + auto res = + gko::test::solve_linear_system(this->exec, linear_system, solver); + + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); for (size_t i = 0; i < num_batch_items; i++) { ASSERT_LE(res.res_norm->get_const_values()[i] / linear_system.rhs_norm->get_const_values()[i],