From 82712a3e9a2b54d2c6b80666ec82dd3dc865f3f8 Mon Sep 17 00:00:00 2001 From: Pratik Nayak <pratikvn@protonmail.com> Date: Wed, 25 Oct 2023 13:37:35 +0200 Subject: [PATCH] Review updates WIP Co-authored-by: Yu-Hsiang Tsai <yhmtsai@gmail.com> Co-authored-by: Marcel Koch <marcel.koch@kit.edu> --- core/base/batch_struct.hpp | 2 +- core/log/logger.cpp | 2 + core/matrix/batch_struct.hpp | 6 +- core/solver/batch_bicgstab.cpp | 2 +- core/solver/batch_bicgstab_kernels.hpp | 2 +- core/solver/batch_dispatch.hpp | 80 +++++++++++------ core/test/solver/batch_bicgstab.cpp | 44 +++++----- core/test/utils/batch_helpers.hpp | 34 +++---- cuda/solver/batch_bicgstab_kernels.cu | 2 +- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 2 +- hip/solver/batch_bicgstab_kernels.hip.cpp | 2 +- include/ginkgo/core/log/batch_logger.hpp | 33 ++++--- include/ginkgo/core/log/logger.hpp | 16 +++- include/ginkgo/core/matrix/batch_identity.hpp | 15 +--- include/ginkgo/core/solver/batch_bicgstab.hpp | 12 +-- .../ginkgo/core/solver/batch_solver_base.hpp | 88 +++++++++++-------- include/ginkgo/core/stop/batch_stop_enum.hpp | 12 +++ omp/solver/batch_bicgstab_kernels.cpp | 9 +- reference/log/batch_logger.hpp | 8 +- reference/solver/batch_bicgstab_kernels.cpp | 2 +- .../solver/batch_bicgstab_kernels.hpp.inc | 11 +-- .../test/solver/batch_bicgstab_kernels.cpp | 6 +- test/solver/batch_bicgstab_kernels.cpp | 4 +- 23 files changed, 225 insertions(+), 169 deletions(-) diff --git a/core/base/batch_struct.hpp b/core/base/batch_struct.hpp index 71445550b87..d7be0837534 100644 --- a/core/base/batch_struct.hpp +++ b/core/base/batch_struct.hpp @@ -71,7 +71,7 @@ struct uniform_batch { int32 num_rows; int32 num_rhs; - size_type get_entry_storage() const + size_type get_storage_size() const { return num_rows * stride * sizeof(value_type); } diff --git a/core/log/logger.cpp b/core/log/logger.cpp index 4b21bfe9b74..3cccb66d34c 100644 --- a/core/log/logger.cpp +++ b/core/log/logger.cpp @@ -75,6 +75,8 @@ constexpr Logger::mask_type Logger::linop_factory_generate_completed_mask; constexpr Logger::mask_type Logger::criterion_check_started_mask; constexpr Logger::mask_type Logger::criterion_check_completed_mask; +constexpr Logger::mask_type Logger::batch_solver_completed_mask; + constexpr Logger::mask_type Logger::iteration_complete_mask; diff --git a/core/matrix/batch_struct.hpp b/core/matrix/batch_struct.hpp index 575c511d051..2e668757b99 100644 --- a/core/matrix/batch_struct.hpp +++ b/core/matrix/batch_struct.hpp @@ -56,7 +56,6 @@ struct batch_item { int32 stride; int32 num_rows; int32 num_cols; - int32 num_nnz = num_rows * stride; }; @@ -73,14 +72,13 @@ struct uniform_batch { int32 stride; int32 num_rows; int32 num_cols; - int32 num_nnz = num_rows * stride; inline size_type get_num_nnz() const { return static_cast<size_type>(stride * num_rows); } - inline size_type get_entry_storage() const + inline size_type get_storage_size() const { return get_num_nnz() * sizeof(value_type); } @@ -132,7 +130,7 @@ struct uniform_batch { return static_cast<size_type>(stride * num_stored_elems_per_row); } - inline size_type get_entry_storage() const + inline size_type get_storage_size() const { return get_num_nnz() * sizeof(value_type); } diff --git a/core/solver/batch_bicgstab.cpp b/core/solver/batch_bicgstab.cpp index 41bc91d72dd..03ee9b9888e 100644 --- a/core/solver/batch_bicgstab.cpp +++ b/core/solver/batch_bicgstab.cpp @@ -56,7 +56,7 @@ GKO_REGISTER_OPERATION(apply, batch_bicgstab::apply); template <typename ValueType> void Bicgstab<ValueType>::solver_apply( const MultiVector<ValueType>* b, MultiVector<ValueType>* x, - log::BatchLogData<remove_complex<ValueType>>* log_data) const + log::detail::BatchLogData<remove_complex<ValueType>>* log_data) const { using MVec = MultiVector<ValueType>; const kernels::batch_bicgstab::BicgstabSettings<remove_complex<ValueType>> diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp index 1c7b955c03f..0fd20ff32b8 100644 --- a/core/solver/batch_bicgstab_kernels.hpp +++ b/core/solver/batch_bicgstab_kernels.hpp @@ -100,7 +100,7 @@ inline int local_memory_requirement(const int num_rows, const int num_rhs) remove_complex<_type>>& options, \ const batch::BatchLinOp* a, const batch::BatchLinOp* preconditioner, \ const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x, \ - gko::batch::log::BatchLogData<remove_complex<_type>>& logdata) + gko::batch::log::detail::BatchLogData<remove_complex<_type>>& logdata) #define GKO_DECLARE_ALL_AS_TEMPLATES \ diff --git a/core/solver/batch_dispatch.hpp b/core/solver/batch_dispatch.hpp index 449f54a7cba..f5029c4aaae 100644 --- a/core/solver/batch_dispatch.hpp +++ b/core/solver/batch_dispatch.hpp @@ -163,7 +163,7 @@ namespace solver { template <typename DValueType> -class DummyKernelCaller { +class KernelCallerInterface { public: template <typename BatchMatrixType, typename PrecType, typename StopType, typename LogType> @@ -174,30 +174,42 @@ class DummyKernelCaller { }; +namespace log { +namespace detail { +/** + * + * Types of batch loggers available. + */ +enum class BatchLogType { simple_convergence_completion }; + + +} // namespace detail +} // namespace log + + /** * Handles dispatching to the correct instantiation of a batched solver * depending on runtime parameters. * - * @tparam KernelCaller Class with an interface like DummyKernelCaller, + * @tparam ValueType The user-facing value type. + * @tparam KernelCaller Class with an interface like KernelCallerInterface, * that is responsible for finally calling the templated backend-specific * kernel. * @tparam SettingsType Structure type of options for the particular solver to * be used. - * @tparam ValueType The user-facing value type. */ -template <typename KernelCaller, typename SettingsType, typename ValueType> +template <typename ValueType, typename KernelCaller, typename SettingsType> class BatchSolverDispatch { public: using value_type = ValueType; using device_value_type = DeviceValueType<ValueType>; using real_type = remove_complex<value_type>; - BatchSolverDispatch(const KernelCaller& kernel_caller, - const SettingsType& settings, - const BatchLinOp* const matrix, - const BatchLinOp* const preconditioner, - const log::BatchLogType logger_type = - log::BatchLogType::simple_convergence_completion) + BatchSolverDispatch( + const KernelCaller& kernel_caller, const SettingsType& settings, + const BatchLinOp* const matrix, const BatchLinOp* const preconditioner, + const log::detail::BatchLogType logger_type = + log::detail::BatchLogType::simple_convergence_completion) : caller_{kernel_caller}, settings_{settings}, mat_{matrix}, @@ -250,9 +262,10 @@ class BatchSolverDispatch { const BatchMatrixType& amat, const multi_vector::uniform_batch<const device_value_type>& b_item, const multi_vector::uniform_batch<device_value_type>& x_item, - log::BatchLogData<real_type>& log_data) + batch::log::detail::BatchLogData<real_type>& log_data) { - if (logger_type_ == log::BatchLogType::simple_convergence_completion) { + if (logger_type_ == + log::detail::BatchLogType::simple_convergence_completion) { device::batch_log::SimpleFinalLogger<real_type> logger( log_data.res_norms.get_data(), log_data.iter_counts.get_data()); dispatch_on_preconditioner(logger, amat, b_item, x_item); @@ -261,19 +274,11 @@ class BatchSolverDispatch { } } - /** - * Solves a linear system from the given data and kernel caller. - * - * @note The correct backend-specific get_batch_struct function needs to be - * available in the current scope. - */ - void apply(const MultiVector<ValueType>* const b, - MultiVector<ValueType>* const x, - log::BatchLogData<real_type>& log_data) + void dispatch_on_matrix( + const multi_vector::uniform_batch<const device_value_type>& b_item, + const multi_vector::uniform_batch<device_value_type>& x_item, + batch::log::detail::BatchLogData<real_type>& log_data) { - const auto x_item = device::get_batch_struct(x); - const auto b_item = device::get_batch_struct(b); - if (auto batch_mat = dynamic_cast<const batch::matrix::Ell<ValueType, int32>*>( mat_)) { @@ -289,12 +294,28 @@ class BatchSolverDispatch { } } + /** + * Solves a linear system from the given data and kernel caller. + * + * @note The correct backend-specific get_batch_struct function needs to be + * available in the current scope. + */ + void apply(const MultiVector<ValueType>* const b, + MultiVector<ValueType>* const x, + batch::log::detail::BatchLogData<real_type>& log_data) + { + const auto x_item = device::get_batch_struct(x); + const auto b_item = device::get_batch_struct(b); + + dispatch_on_matrix(b_item, x_item, log_data); + } + private: const KernelCaller caller_; const SettingsType settings_; const BatchLinOp* mat_; const BatchLinOp* precond_; - const log::BatchLogType logger_type_; + const log::detail::BatchLogType logger_type_; }; @@ -302,13 +323,13 @@ class BatchSolverDispatch { * Convenient function to create a dispatcher. Infers most template arguments. */ template <typename ValueType, typename KernelCaller, typename SettingsType> -BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher( +BatchSolverDispatch<ValueType, KernelCaller, SettingsType> create_dispatcher( const KernelCaller& kernel_caller, const SettingsType& settings, const BatchLinOp* const matrix, const BatchLinOp* const preconditioner, - const log::BatchLogType logger_type = - log::BatchLogType::simple_convergence_completion) + const log::detail::BatchLogType logger_type = + log::detail::BatchLogType::simple_convergence_completion) { - return BatchSolverDispatch<KernelCaller, SettingsType, ValueType>( + return BatchSolverDispatch<ValueType, KernelCaller, SettingsType>( kernel_caller, settings, matrix, preconditioner, logger_type); } @@ -317,4 +338,5 @@ BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher( } // namespace batch } // namespace gko + #endif // GKO_CORE_SOLVER_BATCH_DISPATCH_HPP_ diff --git a/core/test/solver/batch_bicgstab.cpp b/core/test/solver/batch_bicgstab.cpp index ccbb924f1bd..4cf55f871b6 100644 --- a/core/test/solver/batch_bicgstab.cpp +++ b/core/test/solver/batch_bicgstab.cpp @@ -61,7 +61,7 @@ class BatchBicgstab : public ::testing::Test { BatchBicgstab() : exec(gko::ReferenceExecutor::create()), mtx(gko::test::generate_3pt_stencil_batch_matrix<Mtx>( - this->exec->get_master(), nrows, nbatch)), + this->exec->get_master(), num_rows, num_batch_items)), solver_factory(Solver::build() .with_default_max_iterations(def_max_iters) .with_default_tolerance(def_abs_res_tol) @@ -71,8 +71,8 @@ class BatchBicgstab : public ::testing::Test { {} std::shared_ptr<const gko::Executor> exec; - const gko::size_type nbatch = 3; - const int nrows = 5; + const gko::size_type num_batch_items = 3; + const int num_rows = 5; std::shared_ptr<Mtx> mtx; std::unique_ptr<typename Solver::Factory> solver_factory; const int def_max_iters = 100; @@ -94,12 +94,10 @@ TYPED_TEST(BatchBicgstab, FactoryKnowsItsExecutor) TYPED_TEST(BatchBicgstab, FactoryCreatesCorrectSolver) { using Solver = typename TestFixture::Solver; - for (size_t i = 0; i < this->nbatch; i++) { - ASSERT_EQ(this->solver->get_common_size(), - gko::dim<2>(this->nrows, this->nrows)); - } + ASSERT_EQ(this->solver->get_common_size(), + gko::dim<2>(this->num_rows, this->num_rows)); - auto solver = static_cast<Solver*>(this->solver.get()); + auto solver = gko::as<Solver>(this->solver.get()); ASSERT_NE(solver->get_system_matrix(), nullptr); ASSERT_EQ(solver->get_system_matrix(), this->mtx); @@ -114,10 +112,11 @@ TYPED_TEST(BatchBicgstab, CanBeCopied) copy->copy_from(this->solver.get()); - ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows)); - ASSERT_EQ(copy->get_num_batch_items(), this->nbatch); - auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix(); - const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get()); + ASSERT_EQ(copy->get_common_size(), + gko::dim<2>(this->num_rows, this->num_rows)); + ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items); + auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix(); + const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get()); GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0); } @@ -130,10 +129,11 @@ TYPED_TEST(BatchBicgstab, CanBeMoved) copy->move_from(this->solver); - ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows)); - ASSERT_EQ(copy->get_num_batch_items(), this->nbatch); - auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix(); - const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get()); + ASSERT_EQ(copy->get_common_size(), + gko::dim<2>(this->num_rows, this->num_rows)); + ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items); + auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix(); + const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get()); GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0); } @@ -145,10 +145,11 @@ TYPED_TEST(BatchBicgstab, CanBeCloned) auto clone = this->solver->clone(); - ASSERT_EQ(clone->get_common_size(), gko::dim<2>(this->nrows, this->nrows)); - ASSERT_EQ(clone->get_num_batch_items(), this->nbatch); - auto clone_mtx = static_cast<Solver*>(clone.get())->get_system_matrix(); - const auto clone_batch_mtx = static_cast<const Mtx*>(clone_mtx.get()); + ASSERT_EQ(clone->get_common_size(), + gko::dim<2>(this->num_rows, this->num_rows)); + ASSERT_EQ(clone->get_num_batch_items(), this->num_batch_items); + auto clone_mtx = gko::as<Solver>(clone.get())->get_system_matrix(); + const auto clone_batch_mtx = gko::as<const Mtx>(clone_mtx.get()); GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), clone_batch_mtx, 0.0); } @@ -160,8 +161,7 @@ TYPED_TEST(BatchBicgstab, CanBeCleared) this->solver->clear(); ASSERT_EQ(this->solver->get_num_batch_items(), 0); - auto solver_mtx = - static_cast<Solver*>(this->solver.get())->get_system_matrix(); + auto solver_mtx = gko::as<Solver>(this->solver.get())->get_system_matrix(); ASSERT_EQ(solver_mtx, nullptr); } diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp index 7a874677c86..abdc3776603 100644 --- a/core/test/utils/batch_helpers.hpp +++ b/core/test/utils/batch_helpers.hpp @@ -199,7 +199,6 @@ compute_residual_norms( using real_vec = batch::MultiVector<remove_complex<value_type>>; auto exec = mtx->get_executor(); auto num_batch_items = x->get_num_batch_items(); - auto num_rows = x->get_common_size()[0]; auto num_rhs = x->get_common_size()[1]; const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); @@ -221,7 +220,13 @@ struct Result { std::shared_ptr<multi_vec> x; std::shared_ptr<real_vec> res_norm; - std::unique_ptr<gko::batch::log::BatchLogData<remove_complex<ValueType>>> +}; + + +template <typename ValueType> +struct ResultWithLogData : public Result<ValueType> { + std::unique_ptr< + gko::batch::log::detail::BatchLogData<remove_complex<ValueType>>> log_data; }; @@ -255,9 +260,9 @@ Result<typename MatrixType::value_type> solve_linear_system( } -template <typename MatrixType, typename SolveFunction, typename Settings> -Result<typename MatrixType::value_type> solve_linear_system( - std::shared_ptr<const Executor> exec, SolveFunction solve_function, +template <typename MatrixType, typename SolveLambda, typename Settings> +ResultWithLogData<typename MatrixType::value_type> solve_linear_system( + std::shared_ptr<const Executor> exec, SolveLambda solve_lambda, const Settings settings, const LinearSystem<MatrixType>& sys, std::shared_ptr<batch::BatchLinOpFactory> precond_factory = nullptr) { @@ -269,17 +274,15 @@ Result<typename MatrixType::value_type> solve_linear_system( const size_type num_batch_items = sys.matrix->get_num_batch_items(); const int num_rows = sys.matrix->get_common_size()[0]; const int num_rhs = sys.rhs->get_common_size()[1]; - const gko::batch_dim<2> vec_size(num_batch_items, - gko::dim<2>(num_rows, num_rhs)); const gko::batch_dim<2> norm_size(num_batch_items, gko::dim<2>(1, num_rhs)); - Result<value_type> result; - // Initialize r to the original unscaled b + ResultWithLogData<value_type> result; result.x = multi_vec::create_with_config_of(sys.rhs); result.x->fill(zero<value_type>()); - auto log_data = std::make_unique<batch::log::BatchLogData<real_type>>( - exec, num_batch_items); + auto log_data = + std::make_unique<batch::log::detail::BatchLogData<real_type>>( + exec, num_batch_items); std::unique_ptr<gko::batch::BatchLinOp> precond; if (precond_factory) { @@ -288,11 +291,12 @@ Result<typename MatrixType::value_type> solve_linear_system( precond = nullptr; } - solve_function(settings, precond.get(), sys.matrix.get(), sys.rhs.get(), - result.x.get(), *log_data.get()); + solve_lambda(settings, precond.get(), sys.matrix.get(), sys.rhs.get(), + result.x.get(), *log_data.get()); - result.log_data = std::make_unique<batch::log::BatchLogData<real_type>>( - exec->get_master()); + result.log_data = + std::make_unique<batch::log::detail::BatchLogData<real_type>>( + exec->get_master()); result.log_data->iter_counts = log_data->iter_counts; result.log_data->res_norms = log_data->res_norms; diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index fa00bb208af..4f36ed0022d 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -67,7 +67,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec, const batch::BatchLinOp* const precon, const batch::MultiVector<ValueType>* const b, batch::MultiVector<ValueType>* const x, - batch::log::BatchLogData<remove_complex<ValueType>>& logdata) + batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 710c7a78c07..6f82aa8a779 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -64,7 +64,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec, const batch::BatchLinOp* const precon, const batch::MultiVector<ValueType>* const b, batch::MultiVector<ValueType>* const x, - batch::log::BatchLogData<remove_complex<ValueType>>& logdata) + batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 7a52149e21d..8b5abb6a562 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -68,7 +68,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec, const batch::BatchLinOp* const precon, const batch::MultiVector<ValueType>* const b, batch::MultiVector<ValueType>* const x, - batch::log::BatchLogData<remove_complex<ValueType>>& logdata) + batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/include/ginkgo/core/log/batch_logger.hpp b/include/ginkgo/core/log/batch_logger.hpp index 122467893fd..7c1898dd5a0 100644 --- a/include/ginkgo/core/log/batch_logger.hpp +++ b/include/ginkgo/core/log/batch_logger.hpp @@ -50,12 +50,7 @@ namespace batch { * @ingroup log */ namespace log { - - -/** - * Types of batch loggers available. - */ -enum class BatchLogType { simple_convergence_completion }; +namespace detail { /** @@ -64,7 +59,7 @@ enum class BatchLogType { simple_convergence_completion }; * @note Supports only single rhs */ template <typename ValueType> -struct BatchLogData { +struct BatchLogData final { using real_type = remove_complex<ValueType>; BatchLogData(std::shared_ptr<const Executor> exec, @@ -103,6 +98,9 @@ struct BatchLogData { }; +} // namespace detail + + /** * Logs the final residuals and iteration counts for a batch solver. * @@ -127,6 +125,8 @@ class BatchConvergence : public gko::log::Logger { /** * Creates a convergence logger. This dynamically allocates the memory, * constructs the object and returns an std::unique_ptr to this object. + * TODO: See if the objects can be pre-allocated beforehand instead of being + * copied in the `on_<>` event * * @param exec the executor * @param enabled_events the events enabled for this logger. By default all @@ -135,11 +135,11 @@ class BatchConvergence : public gko::log::Logger { * @return an std::unique_ptr to the the constructed object */ static std::unique_ptr<BatchConvergence> create( - std::shared_ptr<const Executor> exec, - const mask_type& enabled_events = gko::log::Logger::all_events_mask) + const mask_type& enabled_events = + gko::log::Logger::batch_solver_completed_mask) { return std::unique_ptr<BatchConvergence>( - new BatchConvergence(exec, enabled_events)); + new BatchConvergence(enabled_events)); } /** @@ -159,17 +159,14 @@ class BatchConvergence : public gko::log::Logger { } protected: - explicit BatchConvergence( - std::shared_ptr<const Executor> exec, - const mask_type& enabled_events = gko::log::Logger::all_events_mask) - : gko::log::Logger(enabled_events), - iteration_count_(exec), - residual_norm_(exec) + explicit BatchConvergence(const mask_type& enabled_events = + gko::log::Logger::batch_solver_completed_mask) + : gko::log::Logger(enabled_events) {} private: - mutable array<int> iteration_count_; - mutable array<real_type> residual_norm_; + mutable array<int> iteration_count_{}; + mutable array<real_type> residual_norm_{}; }; diff --git a/include/ginkgo/core/log/logger.hpp b/include/ginkgo/core/log/logger.hpp index c16e7efbf0d..5f6d0739012 100644 --- a/include/ginkgo/core/log/logger.hpp +++ b/include/ginkgo/core/log/logger.hpp @@ -609,16 +609,30 @@ public: \ std::enable_if_t<Event == 26 && (26 < event_count_max)> on( Params&&... params) const { - if (enabled_events_ & (mask_type{1} << 26)) { + if (enabled_events_ & batch_solver_completed_mask) { this->on_batch_solver_completed(std::forward<Params>(params)...); } } protected: + /** + * Batch solver's event that records the iteration count and the residual + * norm. + * + * @param iters the array of iteration counts. + * @param residual_norms the array storing the residual norms. + */ virtual void on_batch_solver_completed( const array<int>& iters, const array<double>& residual_norms) const {} + /** + * Batch solver's event that records the iteration count and the residual + * norm. + * + * @param iters the array of iteration counts. + * @param residual_norms the array storing the residual norms. + */ virtual void on_batch_solver_completed( const array<int>& iters, const array<float>& residual_norms) const {} diff --git a/include/ginkgo/core/matrix/batch_identity.hpp b/include/ginkgo/core/matrix/batch_identity.hpp index 668fbcc1527..15b7623ac0f 100644 --- a/include/ginkgo/core/matrix/batch_identity.hpp +++ b/include/ginkgo/core/matrix/batch_identity.hpp @@ -48,18 +48,11 @@ namespace matrix { /** - * Identity is a batch matrix format which explicitly stores all values of - * the matrix in each of the batches. - * - * The values in each of the batches are stored in row-major format (values - * belonging to the same row appear consecutive in the memory). Optionally, rows - * can be padded for better memory access. + * The batch Identity matrix, which represents a batch of Identity matrices. * * @tparam ValueType precision of matrix elements * - * @note While this format is not very useful for storing sparse matrices, it - * is often suitable to store vectors, and sets of vectors. - * @ingroup batch_dense + * @ingroup batch_identity * @ingroup mat_formats * @ingroup BatchLinOp */ @@ -81,7 +74,7 @@ class Identity final : public EnableBatchLinOp<Identity<ValueType>>, /** * Apply the matrix to a multi-vector. Represents the matrix vector - * multiplication, x = A * b, where x and b are both multi-vectors. + * multiplication, x = I * b, where x and b are both multi-vectors. * * @param b the multi-vector to be applied to * @param x the output multi-vector @@ -91,7 +84,7 @@ class Identity final : public EnableBatchLinOp<Identity<ValueType>>, /** * Apply the matrix to a multi-vector with a linear combination of the given - * input vector. Represents the matrix vector multiplication, x = alpha * A + * input vector. Represents the matrix vector multiplication, x = alpha * I * * b + beta * x, where x and b are both multi-vectors. * * @param alpha the scalar to scale the matrix-vector product with diff --git a/include/ginkgo/core/solver/batch_bicgstab.hpp b/include/ginkgo/core/solver/batch_bicgstab.hpp index 32a0154f602..29b65bc225a 100644 --- a/include/ginkgo/core/solver/batch_bicgstab.hpp +++ b/include/ginkgo/core/solver/batch_bicgstab.hpp @@ -94,15 +94,15 @@ class Bicgstab final explicit Bicgstab(const Factory* factory, std::shared_ptr<const BatchLinOp> system_matrix) - : EnableBatchSolver<Bicgstab>( - factory->get_executor(), std::move(system_matrix), - detail::extract_common_batch_params(factory->get_parameters())), + : EnableBatchSolver<Bicgstab>(factory->get_executor(), + std::move(system_matrix), + factory->get_parameters()), parameters_{factory->get_parameters()} {} - void solver_apply(const MultiVector<ValueType>* b, - MultiVector<ValueType>* x, - log::BatchLogData<real_type>* log_data) const override; + void solver_apply( + const MultiVector<ValueType>* b, MultiVector<ValueType>* x, + log::detail::BatchLogData<real_type>* log_data) const override; }; diff --git a/include/ginkgo/core/solver/batch_solver_base.hpp b/include/ginkgo/core/solver/batch_solver_base.hpp index c0d5935fa30..2e877d8cb4e 100644 --- a/include/ginkgo/core/solver/batch_solver_base.hpp +++ b/include/ginkgo/core/solver/batch_solver_base.hpp @@ -89,7 +89,13 @@ class BatchSolver { * @param res_tol The residual tolerance to be used for subsequent * invocations of the solver. */ - void set_residual_tolerance(double res_tol) { residual_tol_ = res_tol; } + void set_residual_tolerance(double res_tol) + { + if (res_tol < 0) { + GKO_INVALID_STATE("Tolerance cannot be negative!"); + } + residual_tol_ = res_tol; + } /** * Get the maximum number of iterations set on the solver. @@ -106,19 +112,48 @@ class BatchSolver { */ void set_max_iterations(int max_iterations) { + if (max_iterations < 0) { + GKO_INVALID_STATE("Max iterations cannot be negative!"); + } max_iterations_ = max_iterations; } + /** + * Get the tolerance type. + * + * @return The tolerance type. + */ + ::gko::batch::stop::ToleranceType get_tolerance_type() const + { + return tol_type_; + } + + /** + * Set the type of tolerance check to use inside the solver + * + * @param tol_type The tolerance type. + */ + void set_tolerance_type(::gko::batch::stop::ToleranceType tol_type) + { + if (tol_type != ::gko::batch::stop::ToleranceType::absolute || + tol_type != ::gko::batch::stop::ToleranceType::relative) { + GKO_INVALID_STATE("Invalid tolerance type specified!"); + } + tol_type_ = tol_type; + } + protected: BatchSolver() {} BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix, std::shared_ptr<const BatchLinOp> gen_preconditioner, - const double res_tol, const int max_iterations) + const double res_tol, const int max_iterations, + const ::gko::batch::stop::ToleranceType tol_type) : system_matrix_{std::move(system_matrix)}, preconditioner_{std::move(gen_preconditioner)}, residual_tol_{res_tol}, max_iterations_{max_iterations}, + tol_type_{tol_type}, workspace_{} {} @@ -126,32 +161,11 @@ class BatchSolver { std::shared_ptr<const BatchLinOp> preconditioner_{}; double residual_tol_{}; int max_iterations_{}; + ::gko::batch::stop::ToleranceType tol_type_{}; mutable array<unsigned char> workspace_{}; }; -namespace detail { - - -struct common_batch_params { - std::shared_ptr<const BatchLinOpFactory> prec_factory; - std::shared_ptr<const BatchLinOp> generated_prec; - double residual_tolerance; - int max_iterations; -}; - - -template <typename ParamsType> -common_batch_params extract_common_batch_params(ParamsType& params) -{ - return {params.preconditioner, params.generated_preconditioner, - params.default_tolerance, params.default_max_iterations}; -} - - -} // namespace detail - - /** * The parameter type shared between all preconditioned iterative solvers, * excluding the parameters available in iterative_solver_factory_parameters. @@ -301,11 +315,12 @@ class EnableBatchSolver : EnableBatchLinOp<ConcreteSolver, PolymorphicBase>(std::move(exec)) {} + template <typename FactoryParameters> explicit EnableBatchSolver(std::shared_ptr<const Executor> exec, std::shared_ptr<const BatchLinOp> system_matrix, - detail::common_batch_params common_params) - : BatchSolver(system_matrix, nullptr, common_params.residual_tolerance, - common_params.max_iterations), + const FactoryParameters& params) + : BatchSolver(system_matrix, nullptr, params.default_tolerance, + params.default_max_iterations, params.tolerance_type), EnableBatchLinOp<ConcreteSolver, PolymorphicBase>( exec, gko::transpose(system_matrix->get_size())) { @@ -315,13 +330,12 @@ class EnableBatchSolver using Identity = matrix::Identity<value_type>; using real_type = remove_complex<value_type>; - if (common_params.generated_prec) { - GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(common_params.generated_prec, + if (params.generated_preconditioner) { + GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner, this); - preconditioner_ = std::move(common_params.generated_prec); - } else if (common_params.prec_factory) { - preconditioner_ = - common_params.prec_factory->generate(system_matrix_); + preconditioner_ = std::move(params.generated_preconditioner); + } else if (params.preconditioner) { + preconditioner_ = params.preconditioner->generate(system_matrix_); } else { auto id = Identity::create(exec, system_matrix->get_size()); preconditioner_ = std::move(id); @@ -341,7 +355,7 @@ class EnableBatchSolver if (b->get_common_size()[1] > 1) { GKO_NOT_IMPLEMENTED; } - auto log_data_ = std::make_unique<log::BatchLogData<real_type>>( + auto log_data_ = std::make_unique<log::detail::BatchLogData<real_type>>( exec, b->get_num_batch_items(), workspace_); this->solver_apply(b, x, log_data_.get()); @@ -361,9 +375,9 @@ class EnableBatchSolver x->add_scaled(alpha, x_clone.get()); } - virtual void solver_apply(const MultiVector<ValueType>* b, - MultiVector<ValueType>* x, - log::BatchLogData<real_type>* info) const = 0; + virtual void solver_apply( + const MultiVector<ValueType>* b, MultiVector<ValueType>* x, + log::detail::BatchLogData<real_type>* info) const = 0; }; diff --git a/include/ginkgo/core/stop/batch_stop_enum.hpp b/include/ginkgo/core/stop/batch_stop_enum.hpp index d960e384d24..3199392cf3e 100644 --- a/include/ginkgo/core/stop/batch_stop_enum.hpp +++ b/include/ginkgo/core/stop/batch_stop_enum.hpp @@ -39,6 +39,18 @@ namespace batch { namespace stop { +/** + * This enum provides two types of options for the convergence of an iterative + * solver. + * + * `absolute` tolerance implies that the convergence criteria check is + * against the computed residual ($||r|| <= \tau$, where $||r||$ may be implicit + * or explicit depending on the solver). + * + * With the `relative` tolerance type, the solver + * convergence criteria checks against the relative residual norm + * ($\frac{||r||}{||b||} < \tau$, where $||b||$$ is the L2 norm of the rhs). + */ enum class ToleranceType { absolute, relative }; diff --git a/omp/solver/batch_bicgstab_kernels.cpp b/omp/solver/batch_bicgstab_kernels.cpp index 207ae042a4c..822c8820551 100644 --- a/omp/solver/batch_bicgstab_kernels.cpp +++ b/omp/solver/batch_bicgstab_kernels.cpp @@ -100,12 +100,11 @@ class KernelCaller { // TODO: Align to cache line boundary // TODO: Allocate and free once per thread rather than once per // work-item. - const auto local_space = - static_cast<unsigned char*>(malloc(local_size_bytes)); + auto local_space = array<unsigned char>(exec_, local_size_bytes); batch_entry_bicgstab_impl<StopType, PrecondType, LogType, BatchMatrixType, ValueType>( - settings_, logger, precond, mat, b, x, batch_id, local_space); - free(local_space); + settings_, logger, precond, mat, b, x, batch_id, + local_space.get_data()); } } @@ -122,7 +121,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec, const batch::BatchLinOp* const precond, const batch::MultiVector<ValueType>* const b, batch::MultiVector<ValueType>* const x, - batch::log::BatchLogData<remove_complex<ValueType>>& logdata) + batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata) { auto dispatcher = batch::solver::create_dispatcher<ValueType>( KernelCaller<ValueType>(exec, settings), settings, mat, precond); diff --git a/reference/log/batch_logger.hpp b/reference/log/batch_logger.hpp index e9dadb56ddc..0b1be52e1f4 100644 --- a/reference/log/batch_logger.hpp +++ b/reference/log/batch_logger.hpp @@ -44,7 +44,7 @@ namespace batch_log { /** - * Logs the final residual and iteration count for a batch solver. + * Logs the final residual norm and iteration count for a batch solver. * * @note Supports only a single RHS per batch item. */ @@ -66,11 +66,11 @@ class SimpleFinalLogger final { {} /** - * Logs the iteration count and residual norm. + * Logs the final iteration count and the final residual norm. * * @param batch_idx The index of linear system in the batch to log. - * @param iter The current iteration count (0-based). - * @param res_norm Norm of current residual + * @param iter The final iteration count (0-based). + * @param res_norm Norm of final residual norm */ void log_iteration(const size_type batch_idx, const int iter, const real_type res_norm) diff --git a/reference/solver/batch_bicgstab_kernels.cpp b/reference/solver/batch_bicgstab_kernels.cpp index 5b5d80794ad..b35b28c2cbf 100644 --- a/reference/solver/batch_bicgstab_kernels.cpp +++ b/reference/solver/batch_bicgstab_kernels.cpp @@ -119,7 +119,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec, const batch::BatchLinOp* const precon, const batch::MultiVector<ValueType>* const b, batch::MultiVector<ValueType>* const x, - batch::log::BatchLogData<remove_complex<ValueType>>& log_data) + batch::log::detail::BatchLogData<remove_complex<ValueType>>& log_data) { auto dispatcher = batch::solver::create_dispatcher<ValueType>( KernelCaller<ValueType>(exec, settings), settings, mat, precon); diff --git a/reference/solver/batch_bicgstab_kernels.hpp.inc b/reference/solver/batch_bicgstab_kernels.hpp.inc index 0bf38890fe2..0a281b34d49 100644 --- a/reference/solver/batch_bicgstab_kernels.hpp.inc +++ b/reference/solver/batch_bicgstab_kernels.hpp.inc @@ -274,6 +274,8 @@ inline void batch_entry_bicgstab_impl( for (iter = 0; iter < settings.max_iterations; iter++) { if (stop.check_converged(res_norms_entry.values)) { + logger.log_iteration(batch_item_id, iter, + res_norms_entry.values[0]); break; } @@ -313,13 +315,12 @@ inline void batch_entry_bicgstab_impl( res_norms_entry); if (stop.check_converged(res_norms_entry.values)) { - // update x for the systems (rhs) which converge at this point... x - // = x + alpha*p_hat - // note bits could change from 0 to 1, not the other way round, so - // we can use xor to get info about recent convergence... - // const uint32 converged_recent = converged_prev ^ converged; + // update x for the systems + // x = x + alpha * p_hat update_x_middle(gko::batch::to_const(alpha_entry), gko::batch::to_const(p_hat_entry), x_entry); + logger.log_iteration(batch_item_id, iter, + res_norms_entry.values[0]); break; } diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp index 839f3c6961d..93d34befe91 100644 --- a/reference/test/solver/batch_bicgstab_kernels.cpp +++ b/reference/test/solver/batch_bicgstab_kernels.cpp @@ -64,7 +64,7 @@ class BatchBicgstab : public ::testing::Test { using MVec = gko::batch::MultiVector<value_type>; using RealMVec = gko::batch::MultiVector<real_type>; using Settings = gko::kernels::batch_bicgstab::BicgstabSettings<real_type>; - using LogData = gko::batch::log::BatchLogData<real_type>; + using LogData = gko::batch::log::detail::BatchLogData<real_type>; using LinSys = gko::test::LinearSystem<Mtx>; BatchBicgstab() @@ -181,7 +181,7 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsIterations) auto res = gko::test::solve_linear_system( this->exec, this->solve_lambda, solver_settings, this->linear_system); - const int* const iter_array = res.log_data->iter_counts.get_const_data(); + auto iter_array = res.log_data->iter_counts.get_const_data(); for (size_t i = 0; i < this->num_batch_items; i++) { ASSERT_EQ(iter_array[i], ref_iters); } @@ -239,7 +239,7 @@ TYPED_TEST(BatchBicgstab, ApplyLogsResAndIters) const int num_rows = 13; const size_t num_batch_items = 5; const int num_rhs = 1; - std::shared_ptr<Logger> logger = Logger::create(this->exec); + std::shared_ptr<Logger> logger = Logger::create(); auto linear_system = gko::test::generate_3pt_stencil_batch_problem<Mtx>( this->exec, num_batch_items, num_rows, num_rhs); auto solver = gko::share(solver_factory->generate(linear_system.matrix)); diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp index e29d20cad83..f96e2c0948c 100644 --- a/test/solver/batch_bicgstab_kernels.cpp +++ b/test/solver/batch_bicgstab_kernels.cpp @@ -63,7 +63,7 @@ class BatchBicgstab : public CommonTestFixture { using MVec = gko::batch::MultiVector<value_type>; using RealMVec = gko::batch::MultiVector<real_type>; using Settings = gko::kernels::batch_bicgstab::BicgstabSettings<real_type>; - using LogData = gko::batch::log::BatchLogData<real_type>; + using LogData = gko::batch::log::detail::BatchLogData<real_type>; using Logger = gko::batch::log::BatchConvergence<real_type>; BatchBicgstab() {} @@ -207,7 +207,7 @@ TEST_F(BatchBicgstab, CanSolveLargeHpdSystem) .with_default_tolerance(tol) .with_tolerance_type(gko::batch::stop::ToleranceType::absolute) .on(exec); - std::shared_ptr<Logger> logger = Logger::create(exec); + std::shared_ptr<Logger> logger = Logger::create(); auto linear_system = gko::test::generate_diag_dominant_batch_problem<Mtx>( exec, num_batch_items, num_rows, num_rhs, true); auto solver = gko::share(solver_factory->generate(linear_system.matrix));