From 82712a3e9a2b54d2c6b80666ec82dd3dc865f3f8 Mon Sep 17 00:00:00 2001
From: Pratik Nayak <pratikvn@protonmail.com>
Date: Wed, 25 Oct 2023 13:37:35 +0200
Subject: [PATCH] Review updates WIP

Co-authored-by: Yu-Hsiang Tsai <yhmtsai@gmail.com>
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
---
 core/base/batch_struct.hpp                    |  2 +-
 core/log/logger.cpp                           |  2 +
 core/matrix/batch_struct.hpp                  |  6 +-
 core/solver/batch_bicgstab.cpp                |  2 +-
 core/solver/batch_bicgstab_kernels.hpp        |  2 +-
 core/solver/batch_dispatch.hpp                | 80 +++++++++++------
 core/test/solver/batch_bicgstab.cpp           | 44 +++++-----
 core/test/utils/batch_helpers.hpp             | 34 +++----
 cuda/solver/batch_bicgstab_kernels.cu         |  2 +-
 dpcpp/solver/batch_bicgstab_kernels.dp.cpp    |  2 +-
 hip/solver/batch_bicgstab_kernels.hip.cpp     |  2 +-
 include/ginkgo/core/log/batch_logger.hpp      | 33 ++++---
 include/ginkgo/core/log/logger.hpp            | 16 +++-
 include/ginkgo/core/matrix/batch_identity.hpp | 15 +---
 include/ginkgo/core/solver/batch_bicgstab.hpp | 12 +--
 .../ginkgo/core/solver/batch_solver_base.hpp  | 88 +++++++++++--------
 include/ginkgo/core/stop/batch_stop_enum.hpp  | 12 +++
 omp/solver/batch_bicgstab_kernels.cpp         |  9 +-
 reference/log/batch_logger.hpp                |  8 +-
 reference/solver/batch_bicgstab_kernels.cpp   |  2 +-
 .../solver/batch_bicgstab_kernels.hpp.inc     | 11 +--
 .../test/solver/batch_bicgstab_kernels.cpp    |  6 +-
 test/solver/batch_bicgstab_kernels.cpp        |  4 +-
 23 files changed, 225 insertions(+), 169 deletions(-)

diff --git a/core/base/batch_struct.hpp b/core/base/batch_struct.hpp
index 71445550b87..d7be0837534 100644
--- a/core/base/batch_struct.hpp
+++ b/core/base/batch_struct.hpp
@@ -71,7 +71,7 @@ struct uniform_batch {
     int32 num_rows;
     int32 num_rhs;
 
-    size_type get_entry_storage() const
+    size_type get_storage_size() const
     {
         return num_rows * stride * sizeof(value_type);
     }
diff --git a/core/log/logger.cpp b/core/log/logger.cpp
index 4b21bfe9b74..3cccb66d34c 100644
--- a/core/log/logger.cpp
+++ b/core/log/logger.cpp
@@ -75,6 +75,8 @@ constexpr Logger::mask_type Logger::linop_factory_generate_completed_mask;
 constexpr Logger::mask_type Logger::criterion_check_started_mask;
 constexpr Logger::mask_type Logger::criterion_check_completed_mask;
 
+constexpr Logger::mask_type Logger::batch_solver_completed_mask;
+
 constexpr Logger::mask_type Logger::iteration_complete_mask;
 
 
diff --git a/core/matrix/batch_struct.hpp b/core/matrix/batch_struct.hpp
index 575c511d051..2e668757b99 100644
--- a/core/matrix/batch_struct.hpp
+++ b/core/matrix/batch_struct.hpp
@@ -56,7 +56,6 @@ struct batch_item {
     int32 stride;
     int32 num_rows;
     int32 num_cols;
-    int32 num_nnz = num_rows * stride;
 };
 
 
@@ -73,14 +72,13 @@ struct uniform_batch {
     int32 stride;
     int32 num_rows;
     int32 num_cols;
-    int32 num_nnz = num_rows * stride;
 
     inline size_type get_num_nnz() const
     {
         return static_cast<size_type>(stride * num_rows);
     }
 
-    inline size_type get_entry_storage() const
+    inline size_type get_storage_size() const
     {
         return get_num_nnz() * sizeof(value_type);
     }
@@ -132,7 +130,7 @@ struct uniform_batch {
         return static_cast<size_type>(stride * num_stored_elems_per_row);
     }
 
-    inline size_type get_entry_storage() const
+    inline size_type get_storage_size() const
     {
         return get_num_nnz() * sizeof(value_type);
     }
diff --git a/core/solver/batch_bicgstab.cpp b/core/solver/batch_bicgstab.cpp
index 41bc91d72dd..03ee9b9888e 100644
--- a/core/solver/batch_bicgstab.cpp
+++ b/core/solver/batch_bicgstab.cpp
@@ -56,7 +56,7 @@ GKO_REGISTER_OPERATION(apply, batch_bicgstab::apply);
 template <typename ValueType>
 void Bicgstab<ValueType>::solver_apply(
     const MultiVector<ValueType>* b, MultiVector<ValueType>* x,
-    log::BatchLogData<remove_complex<ValueType>>* log_data) const
+    log::detail::BatchLogData<remove_complex<ValueType>>* log_data) const
 {
     using MVec = MultiVector<ValueType>;
     const kernels::batch_bicgstab::BicgstabSettings<remove_complex<ValueType>>
diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp
index 1c7b955c03f..0fd20ff32b8 100644
--- a/core/solver/batch_bicgstab_kernels.hpp
+++ b/core/solver/batch_bicgstab_kernels.hpp
@@ -100,7 +100,7 @@ inline int local_memory_requirement(const int num_rows, const int num_rhs)
             remove_complex<_type>>& options,                                 \
         const batch::BatchLinOp* a, const batch::BatchLinOp* preconditioner, \
         const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x,    \
-        gko::batch::log::BatchLogData<remove_complex<_type>>& logdata)
+        gko::batch::log::detail::BatchLogData<remove_complex<_type>>& logdata)
 
 
 #define GKO_DECLARE_ALL_AS_TEMPLATES \
diff --git a/core/solver/batch_dispatch.hpp b/core/solver/batch_dispatch.hpp
index 449f54a7cba..f5029c4aaae 100644
--- a/core/solver/batch_dispatch.hpp
+++ b/core/solver/batch_dispatch.hpp
@@ -163,7 +163,7 @@ namespace solver {
 
 
 template <typename DValueType>
-class DummyKernelCaller {
+class KernelCallerInterface {
 public:
     template <typename BatchMatrixType, typename PrecType, typename StopType,
               typename LogType>
@@ -174,30 +174,42 @@ class DummyKernelCaller {
 };
 
 
+namespace log {
+namespace detail {
+/**
+ *
+ * Types of batch loggers available.
+ */
+enum class BatchLogType { simple_convergence_completion };
+
+
+}  // namespace detail
+}  // namespace log
+
+
 /**
  * Handles dispatching to the correct instantiation of a batched solver
  * depending on runtime parameters.
  *
- * @tparam KernelCaller  Class with an interface like DummyKernelCaller,
+ * @tparam ValueType  The user-facing value type.
+ * @tparam KernelCaller  Class with an interface like KernelCallerInterface,
  *   that is responsible for finally calling the templated backend-specific
  *   kernel.
  * @tparam SettingsType  Structure type of options for the particular solver to
  * be used.
- * @tparam ValueType  The user-facing value type.
  */
-template <typename KernelCaller, typename SettingsType, typename ValueType>
+template <typename ValueType, typename KernelCaller, typename SettingsType>
 class BatchSolverDispatch {
 public:
     using value_type = ValueType;
     using device_value_type = DeviceValueType<ValueType>;
     using real_type = remove_complex<value_type>;
 
-    BatchSolverDispatch(const KernelCaller& kernel_caller,
-                        const SettingsType& settings,
-                        const BatchLinOp* const matrix,
-                        const BatchLinOp* const preconditioner,
-                        const log::BatchLogType logger_type =
-                            log::BatchLogType::simple_convergence_completion)
+    BatchSolverDispatch(
+        const KernelCaller& kernel_caller, const SettingsType& settings,
+        const BatchLinOp* const matrix, const BatchLinOp* const preconditioner,
+        const log::detail::BatchLogType logger_type =
+            log::detail::BatchLogType::simple_convergence_completion)
         : caller_{kernel_caller},
           settings_{settings},
           mat_{matrix},
@@ -250,9 +262,10 @@ class BatchSolverDispatch {
         const BatchMatrixType& amat,
         const multi_vector::uniform_batch<const device_value_type>& b_item,
         const multi_vector::uniform_batch<device_value_type>& x_item,
-        log::BatchLogData<real_type>& log_data)
+        batch::log::detail::BatchLogData<real_type>& log_data)
     {
-        if (logger_type_ == log::BatchLogType::simple_convergence_completion) {
+        if (logger_type_ ==
+            log::detail::BatchLogType::simple_convergence_completion) {
             device::batch_log::SimpleFinalLogger<real_type> logger(
                 log_data.res_norms.get_data(), log_data.iter_counts.get_data());
             dispatch_on_preconditioner(logger, amat, b_item, x_item);
@@ -261,19 +274,11 @@ class BatchSolverDispatch {
         }
     }
 
-    /**
-     * Solves a linear system from the given data and kernel caller.
-     *
-     * @note The correct backend-specific get_batch_struct function needs to be
-     * available in the current scope.
-     */
-    void apply(const MultiVector<ValueType>* const b,
-               MultiVector<ValueType>* const x,
-               log::BatchLogData<real_type>& log_data)
+    void dispatch_on_matrix(
+        const multi_vector::uniform_batch<const device_value_type>& b_item,
+        const multi_vector::uniform_batch<device_value_type>& x_item,
+        batch::log::detail::BatchLogData<real_type>& log_data)
     {
-        const auto x_item = device::get_batch_struct(x);
-        const auto b_item = device::get_batch_struct(b);
-
         if (auto batch_mat =
                 dynamic_cast<const batch::matrix::Ell<ValueType, int32>*>(
                     mat_)) {
@@ -289,12 +294,28 @@ class BatchSolverDispatch {
         }
     }
 
+    /**
+     * Solves a linear system from the given data and kernel caller.
+     *
+     * @note The correct backend-specific get_batch_struct function needs to be
+     * available in the current scope.
+     */
+    void apply(const MultiVector<ValueType>* const b,
+               MultiVector<ValueType>* const x,
+               batch::log::detail::BatchLogData<real_type>& log_data)
+    {
+        const auto x_item = device::get_batch_struct(x);
+        const auto b_item = device::get_batch_struct(b);
+
+        dispatch_on_matrix(b_item, x_item, log_data);
+    }
+
 private:
     const KernelCaller caller_;
     const SettingsType settings_;
     const BatchLinOp* mat_;
     const BatchLinOp* precond_;
-    const log::BatchLogType logger_type_;
+    const log::detail::BatchLogType logger_type_;
 };
 
 
@@ -302,13 +323,13 @@ class BatchSolverDispatch {
  * Convenient function to create a dispatcher. Infers most template arguments.
  */
 template <typename ValueType, typename KernelCaller, typename SettingsType>
-BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher(
+BatchSolverDispatch<ValueType, KernelCaller, SettingsType> create_dispatcher(
     const KernelCaller& kernel_caller, const SettingsType& settings,
     const BatchLinOp* const matrix, const BatchLinOp* const preconditioner,
-    const log::BatchLogType logger_type =
-        log::BatchLogType::simple_convergence_completion)
+    const log::detail::BatchLogType logger_type =
+        log::detail::BatchLogType::simple_convergence_completion)
 {
-    return BatchSolverDispatch<KernelCaller, SettingsType, ValueType>(
+    return BatchSolverDispatch<ValueType, KernelCaller, SettingsType>(
         kernel_caller, settings, matrix, preconditioner, logger_type);
 }
 
@@ -317,4 +338,5 @@ BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher(
 }  // namespace batch
 }  // namespace gko
 
+
 #endif  // GKO_CORE_SOLVER_BATCH_DISPATCH_HPP_
diff --git a/core/test/solver/batch_bicgstab.cpp b/core/test/solver/batch_bicgstab.cpp
index ccbb924f1bd..4cf55f871b6 100644
--- a/core/test/solver/batch_bicgstab.cpp
+++ b/core/test/solver/batch_bicgstab.cpp
@@ -61,7 +61,7 @@ class BatchBicgstab : public ::testing::Test {
     BatchBicgstab()
         : exec(gko::ReferenceExecutor::create()),
           mtx(gko::test::generate_3pt_stencil_batch_matrix<Mtx>(
-              this->exec->get_master(), nrows, nbatch)),
+              this->exec->get_master(), num_rows, num_batch_items)),
           solver_factory(Solver::build()
                              .with_default_max_iterations(def_max_iters)
                              .with_default_tolerance(def_abs_res_tol)
@@ -71,8 +71,8 @@ class BatchBicgstab : public ::testing::Test {
     {}
 
     std::shared_ptr<const gko::Executor> exec;
-    const gko::size_type nbatch = 3;
-    const int nrows = 5;
+    const gko::size_type num_batch_items = 3;
+    const int num_rows = 5;
     std::shared_ptr<Mtx> mtx;
     std::unique_ptr<typename Solver::Factory> solver_factory;
     const int def_max_iters = 100;
@@ -94,12 +94,10 @@ TYPED_TEST(BatchBicgstab, FactoryKnowsItsExecutor)
 TYPED_TEST(BatchBicgstab, FactoryCreatesCorrectSolver)
 {
     using Solver = typename TestFixture::Solver;
-    for (size_t i = 0; i < this->nbatch; i++) {
-        ASSERT_EQ(this->solver->get_common_size(),
-                  gko::dim<2>(this->nrows, this->nrows));
-    }
+    ASSERT_EQ(this->solver->get_common_size(),
+              gko::dim<2>(this->num_rows, this->num_rows));
 
-    auto solver = static_cast<Solver*>(this->solver.get());
+    auto solver = gko::as<Solver>(this->solver.get());
 
     ASSERT_NE(solver->get_system_matrix(), nullptr);
     ASSERT_EQ(solver->get_system_matrix(), this->mtx);
@@ -114,10 +112,11 @@ TYPED_TEST(BatchBicgstab, CanBeCopied)
 
     copy->copy_from(this->solver.get());
 
-    ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
-    ASSERT_EQ(copy->get_num_batch_items(), this->nbatch);
-    auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix();
-    const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get());
+    ASSERT_EQ(copy->get_common_size(),
+              gko::dim<2>(this->num_rows, this->num_rows));
+    ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items);
+    auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix();
+    const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get());
     GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0);
 }
 
@@ -130,10 +129,11 @@ TYPED_TEST(BatchBicgstab, CanBeMoved)
 
     copy->move_from(this->solver);
 
-    ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
-    ASSERT_EQ(copy->get_num_batch_items(), this->nbatch);
-    auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix();
-    const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get());
+    ASSERT_EQ(copy->get_common_size(),
+              gko::dim<2>(this->num_rows, this->num_rows));
+    ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items);
+    auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix();
+    const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get());
     GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0);
 }
 
@@ -145,10 +145,11 @@ TYPED_TEST(BatchBicgstab, CanBeCloned)
 
     auto clone = this->solver->clone();
 
-    ASSERT_EQ(clone->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
-    ASSERT_EQ(clone->get_num_batch_items(), this->nbatch);
-    auto clone_mtx = static_cast<Solver*>(clone.get())->get_system_matrix();
-    const auto clone_batch_mtx = static_cast<const Mtx*>(clone_mtx.get());
+    ASSERT_EQ(clone->get_common_size(),
+              gko::dim<2>(this->num_rows, this->num_rows));
+    ASSERT_EQ(clone->get_num_batch_items(), this->num_batch_items);
+    auto clone_mtx = gko::as<Solver>(clone.get())->get_system_matrix();
+    const auto clone_batch_mtx = gko::as<const Mtx>(clone_mtx.get());
     GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), clone_batch_mtx, 0.0);
 }
 
@@ -160,8 +161,7 @@ TYPED_TEST(BatchBicgstab, CanBeCleared)
     this->solver->clear();
 
     ASSERT_EQ(this->solver->get_num_batch_items(), 0);
-    auto solver_mtx =
-        static_cast<Solver*>(this->solver.get())->get_system_matrix();
+    auto solver_mtx = gko::as<Solver>(this->solver.get())->get_system_matrix();
     ASSERT_EQ(solver_mtx, nullptr);
 }
 
diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp
index 7a874677c86..abdc3776603 100644
--- a/core/test/utils/batch_helpers.hpp
+++ b/core/test/utils/batch_helpers.hpp
@@ -199,7 +199,6 @@ compute_residual_norms(
     using real_vec = batch::MultiVector<remove_complex<value_type>>;
     auto exec = mtx->get_executor();
     auto num_batch_items = x->get_num_batch_items();
-    auto num_rows = x->get_common_size()[0];
     auto num_rhs = x->get_common_size()[1];
     const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs));
 
@@ -221,7 +220,13 @@ struct Result {
 
     std::shared_ptr<multi_vec> x;
     std::shared_ptr<real_vec> res_norm;
-    std::unique_ptr<gko::batch::log::BatchLogData<remove_complex<ValueType>>>
+};
+
+
+template <typename ValueType>
+struct ResultWithLogData : public Result<ValueType> {
+    std::unique_ptr<
+        gko::batch::log::detail::BatchLogData<remove_complex<ValueType>>>
         log_data;
 };
 
@@ -255,9 +260,9 @@ Result<typename MatrixType::value_type> solve_linear_system(
 }
 
 
-template <typename MatrixType, typename SolveFunction, typename Settings>
-Result<typename MatrixType::value_type> solve_linear_system(
-    std::shared_ptr<const Executor> exec, SolveFunction solve_function,
+template <typename MatrixType, typename SolveLambda, typename Settings>
+ResultWithLogData<typename MatrixType::value_type> solve_linear_system(
+    std::shared_ptr<const Executor> exec, SolveLambda solve_lambda,
     const Settings settings, const LinearSystem<MatrixType>& sys,
     std::shared_ptr<batch::BatchLinOpFactory> precond_factory = nullptr)
 {
@@ -269,17 +274,15 @@ Result<typename MatrixType::value_type> solve_linear_system(
     const size_type num_batch_items = sys.matrix->get_num_batch_items();
     const int num_rows = sys.matrix->get_common_size()[0];
     const int num_rhs = sys.rhs->get_common_size()[1];
-    const gko::batch_dim<2> vec_size(num_batch_items,
-                                     gko::dim<2>(num_rows, num_rhs));
     const gko::batch_dim<2> norm_size(num_batch_items, gko::dim<2>(1, num_rhs));
 
-    Result<value_type> result;
-    // Initialize r to the original unscaled b
+    ResultWithLogData<value_type> result;
     result.x = multi_vec::create_with_config_of(sys.rhs);
     result.x->fill(zero<value_type>());
 
-    auto log_data = std::make_unique<batch::log::BatchLogData<real_type>>(
-        exec, num_batch_items);
+    auto log_data =
+        std::make_unique<batch::log::detail::BatchLogData<real_type>>(
+            exec, num_batch_items);
 
     std::unique_ptr<gko::batch::BatchLinOp> precond;
     if (precond_factory) {
@@ -288,11 +291,12 @@ Result<typename MatrixType::value_type> solve_linear_system(
         precond = nullptr;
     }
 
-    solve_function(settings, precond.get(), sys.matrix.get(), sys.rhs.get(),
-                   result.x.get(), *log_data.get());
+    solve_lambda(settings, precond.get(), sys.matrix.get(), sys.rhs.get(),
+                 result.x.get(), *log_data.get());
 
-    result.log_data = std::make_unique<batch::log::BatchLogData<real_type>>(
-        exec->get_master());
+    result.log_data =
+        std::make_unique<batch::log::detail::BatchLogData<real_type>>(
+            exec->get_master());
     result.log_data->iter_counts = log_data->iter_counts;
     result.log_data->res_norms = log_data->res_norms;
 
diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu
index fa00bb208af..4f36ed0022d 100644
--- a/cuda/solver/batch_bicgstab_kernels.cu
+++ b/cuda/solver/batch_bicgstab_kernels.cu
@@ -67,7 +67,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
            const batch::BatchLinOp* const precon,
            const batch::MultiVector<ValueType>* const b,
            batch::MultiVector<ValueType>* const x,
-           batch::log::BatchLogData<remove_complex<ValueType>>& logdata)
+           batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata)
     GKO_NOT_IMPLEMENTED;
 
 GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);
diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
index 710c7a78c07..6f82aa8a779 100644
--- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
+++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
@@ -64,7 +64,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
            const batch::BatchLinOp* const precon,
            const batch::MultiVector<ValueType>* const b,
            batch::MultiVector<ValueType>* const x,
-           batch::log::BatchLogData<remove_complex<ValueType>>& logdata)
+           batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata)
     GKO_NOT_IMPLEMENTED;
 
 GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);
diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp
index 7a52149e21d..8b5abb6a562 100644
--- a/hip/solver/batch_bicgstab_kernels.hip.cpp
+++ b/hip/solver/batch_bicgstab_kernels.hip.cpp
@@ -68,7 +68,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
            const batch::BatchLinOp* const precon,
            const batch::MultiVector<ValueType>* const b,
            batch::MultiVector<ValueType>* const x,
-           batch::log::BatchLogData<remove_complex<ValueType>>& logdata)
+           batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata)
     GKO_NOT_IMPLEMENTED;
 
 GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);
diff --git a/include/ginkgo/core/log/batch_logger.hpp b/include/ginkgo/core/log/batch_logger.hpp
index 122467893fd..7c1898dd5a0 100644
--- a/include/ginkgo/core/log/batch_logger.hpp
+++ b/include/ginkgo/core/log/batch_logger.hpp
@@ -50,12 +50,7 @@ namespace batch {
  * @ingroup log
  */
 namespace log {
-
-
-/**
- * Types of batch loggers available.
- */
-enum class BatchLogType { simple_convergence_completion };
+namespace detail {
 
 
 /**
@@ -64,7 +59,7 @@ enum class BatchLogType { simple_convergence_completion };
  * @note Supports only single rhs
  */
 template <typename ValueType>
-struct BatchLogData {
+struct BatchLogData final {
     using real_type = remove_complex<ValueType>;
 
     BatchLogData(std::shared_ptr<const Executor> exec,
@@ -103,6 +98,9 @@ struct BatchLogData {
 };
 
 
+}  // namespace detail
+
+
 /**
  * Logs the final residuals and iteration counts for a batch solver.
  *
@@ -127,6 +125,8 @@ class BatchConvergence : public gko::log::Logger {
     /**
      * Creates a convergence logger. This dynamically allocates the memory,
      * constructs the object and returns an std::unique_ptr to this object.
+     * TODO: See if the objects can be pre-allocated beforehand instead of being
+     * copied in the `on_<>` event
      *
      * @param exec  the executor
      * @param enabled_events  the events enabled for this logger. By default all
@@ -135,11 +135,11 @@ class BatchConvergence : public gko::log::Logger {
      * @return an std::unique_ptr to the the constructed object
      */
     static std::unique_ptr<BatchConvergence> create(
-        std::shared_ptr<const Executor> exec,
-        const mask_type& enabled_events = gko::log::Logger::all_events_mask)
+        const mask_type& enabled_events =
+            gko::log::Logger::batch_solver_completed_mask)
     {
         return std::unique_ptr<BatchConvergence>(
-            new BatchConvergence(exec, enabled_events));
+            new BatchConvergence(enabled_events));
     }
 
     /**
@@ -159,17 +159,14 @@ class BatchConvergence : public gko::log::Logger {
     }
 
 protected:
-    explicit BatchConvergence(
-        std::shared_ptr<const Executor> exec,
-        const mask_type& enabled_events = gko::log::Logger::all_events_mask)
-        : gko::log::Logger(enabled_events),
-          iteration_count_(exec),
-          residual_norm_(exec)
+    explicit BatchConvergence(const mask_type& enabled_events =
+                                  gko::log::Logger::batch_solver_completed_mask)
+        : gko::log::Logger(enabled_events)
     {}
 
 private:
-    mutable array<int> iteration_count_;
-    mutable array<real_type> residual_norm_;
+    mutable array<int> iteration_count_{};
+    mutable array<real_type> residual_norm_{};
 };
 
 
diff --git a/include/ginkgo/core/log/logger.hpp b/include/ginkgo/core/log/logger.hpp
index c16e7efbf0d..5f6d0739012 100644
--- a/include/ginkgo/core/log/logger.hpp
+++ b/include/ginkgo/core/log/logger.hpp
@@ -609,16 +609,30 @@ public:                                                              \
     std::enable_if_t<Event == 26 && (26 < event_count_max)> on(
         Params&&... params) const
     {
-        if (enabled_events_ & (mask_type{1} << 26)) {
+        if (enabled_events_ & batch_solver_completed_mask) {
             this->on_batch_solver_completed(std::forward<Params>(params)...);
         }
     }
 
 protected:
+    /**
+     * Batch solver's event that records the iteration count and the residual
+     * norm.
+     *
+     * @param iters  the array of iteration counts.
+     * @param residual_norms  the array storing the residual norms.
+     */
     virtual void on_batch_solver_completed(
         const array<int>& iters, const array<double>& residual_norms) const
     {}
 
+    /**
+     * Batch solver's event that records the iteration count and the residual
+     * norm.
+     *
+     * @param iters  the array of iteration counts.
+     * @param residual_norms  the array storing the residual norms.
+     */
     virtual void on_batch_solver_completed(
         const array<int>& iters, const array<float>& residual_norms) const
     {}
diff --git a/include/ginkgo/core/matrix/batch_identity.hpp b/include/ginkgo/core/matrix/batch_identity.hpp
index 668fbcc1527..15b7623ac0f 100644
--- a/include/ginkgo/core/matrix/batch_identity.hpp
+++ b/include/ginkgo/core/matrix/batch_identity.hpp
@@ -48,18 +48,11 @@ namespace matrix {
 
 
 /**
- * Identity is a batch matrix format which explicitly stores all values of
- * the matrix in each of the batches.
- *
- * The values in each of the batches are stored in row-major format (values
- * belonging to the same row appear consecutive in the memory). Optionally, rows
- * can be padded for better memory access.
+ * The batch Identity matrix, which represents a batch of Identity matrices.
  *
  * @tparam ValueType  precision of matrix elements
  *
- * @note While this format is not very useful for storing sparse matrices, it
- *       is often suitable to store vectors, and sets of vectors.
- * @ingroup batch_dense
+ * @ingroup batch_identity
  * @ingroup mat_formats
  * @ingroup BatchLinOp
  */
@@ -81,7 +74,7 @@ class Identity final : public EnableBatchLinOp<Identity<ValueType>>,
 
     /**
      * Apply the matrix to a multi-vector. Represents the matrix vector
-     * multiplication, x = A * b, where x and b are both multi-vectors.
+     * multiplication, x = I * b, where x and b are both multi-vectors.
      *
      * @param b  the multi-vector to be applied to
      * @param x  the output multi-vector
@@ -91,7 +84,7 @@ class Identity final : public EnableBatchLinOp<Identity<ValueType>>,
 
     /**
      * Apply the matrix to a multi-vector with a linear combination of the given
-     * input vector. Represents the matrix vector multiplication, x = alpha * A
+     * input vector. Represents the matrix vector multiplication, x = alpha * I
      * * b + beta * x, where x and b are both multi-vectors.
      *
      * @param alpha  the scalar to scale the matrix-vector product with
diff --git a/include/ginkgo/core/solver/batch_bicgstab.hpp b/include/ginkgo/core/solver/batch_bicgstab.hpp
index 32a0154f602..29b65bc225a 100644
--- a/include/ginkgo/core/solver/batch_bicgstab.hpp
+++ b/include/ginkgo/core/solver/batch_bicgstab.hpp
@@ -94,15 +94,15 @@ class Bicgstab final
 
     explicit Bicgstab(const Factory* factory,
                       std::shared_ptr<const BatchLinOp> system_matrix)
-        : EnableBatchSolver<Bicgstab>(
-              factory->get_executor(), std::move(system_matrix),
-              detail::extract_common_batch_params(factory->get_parameters())),
+        : EnableBatchSolver<Bicgstab>(factory->get_executor(),
+                                      std::move(system_matrix),
+                                      factory->get_parameters()),
           parameters_{factory->get_parameters()}
     {}
 
-    void solver_apply(const MultiVector<ValueType>* b,
-                      MultiVector<ValueType>* x,
-                      log::BatchLogData<real_type>* log_data) const override;
+    void solver_apply(
+        const MultiVector<ValueType>* b, MultiVector<ValueType>* x,
+        log::detail::BatchLogData<real_type>* log_data) const override;
 };
 
 
diff --git a/include/ginkgo/core/solver/batch_solver_base.hpp b/include/ginkgo/core/solver/batch_solver_base.hpp
index c0d5935fa30..2e877d8cb4e 100644
--- a/include/ginkgo/core/solver/batch_solver_base.hpp
+++ b/include/ginkgo/core/solver/batch_solver_base.hpp
@@ -89,7 +89,13 @@ class BatchSolver {
      * @param res_tol  The residual tolerance to be used for subsequent
      *                 invocations of the solver.
      */
-    void set_residual_tolerance(double res_tol) { residual_tol_ = res_tol; }
+    void set_residual_tolerance(double res_tol)
+    {
+        if (res_tol < 0) {
+            GKO_INVALID_STATE("Tolerance cannot be negative!");
+        }
+        residual_tol_ = res_tol;
+    }
 
     /**
      * Get the maximum number of iterations set on the solver.
@@ -106,19 +112,48 @@ class BatchSolver {
      */
     void set_max_iterations(int max_iterations)
     {
+        if (max_iterations < 0) {
+            GKO_INVALID_STATE("Max iterations cannot be negative!");
+        }
         max_iterations_ = max_iterations;
     }
 
+    /**
+     * Get the tolerance type.
+     *
+     * @return  The tolerance type.
+     */
+    ::gko::batch::stop::ToleranceType get_tolerance_type() const
+    {
+        return tol_type_;
+    }
+
+    /**
+     * Set the type of tolerance check to use inside the solver
+     *
+     * @param tol_type  The tolerance type.
+     */
+    void set_tolerance_type(::gko::batch::stop::ToleranceType tol_type)
+    {
+        if (tol_type != ::gko::batch::stop::ToleranceType::absolute ||
+            tol_type != ::gko::batch::stop::ToleranceType::relative) {
+            GKO_INVALID_STATE("Invalid tolerance type specified!");
+        }
+        tol_type_ = tol_type;
+    }
+
 protected:
     BatchSolver() {}
 
     BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix,
                 std::shared_ptr<const BatchLinOp> gen_preconditioner,
-                const double res_tol, const int max_iterations)
+                const double res_tol, const int max_iterations,
+                const ::gko::batch::stop::ToleranceType tol_type)
         : system_matrix_{std::move(system_matrix)},
           preconditioner_{std::move(gen_preconditioner)},
           residual_tol_{res_tol},
           max_iterations_{max_iterations},
+          tol_type_{tol_type},
           workspace_{}
     {}
 
@@ -126,32 +161,11 @@ class BatchSolver {
     std::shared_ptr<const BatchLinOp> preconditioner_{};
     double residual_tol_{};
     int max_iterations_{};
+    ::gko::batch::stop::ToleranceType tol_type_{};
     mutable array<unsigned char> workspace_{};
 };
 
 
-namespace detail {
-
-
-struct common_batch_params {
-    std::shared_ptr<const BatchLinOpFactory> prec_factory;
-    std::shared_ptr<const BatchLinOp> generated_prec;
-    double residual_tolerance;
-    int max_iterations;
-};
-
-
-template <typename ParamsType>
-common_batch_params extract_common_batch_params(ParamsType& params)
-{
-    return {params.preconditioner, params.generated_preconditioner,
-            params.default_tolerance, params.default_max_iterations};
-}
-
-
-}  // namespace detail
-
-
 /**
  * The parameter type shared between all preconditioned iterative solvers,
  * excluding the parameters available in iterative_solver_factory_parameters.
@@ -301,11 +315,12 @@ class EnableBatchSolver
         : EnableBatchLinOp<ConcreteSolver, PolymorphicBase>(std::move(exec))
     {}
 
+    template <typename FactoryParameters>
     explicit EnableBatchSolver(std::shared_ptr<const Executor> exec,
                                std::shared_ptr<const BatchLinOp> system_matrix,
-                               detail::common_batch_params common_params)
-        : BatchSolver(system_matrix, nullptr, common_params.residual_tolerance,
-                      common_params.max_iterations),
+                               const FactoryParameters& params)
+        : BatchSolver(system_matrix, nullptr, params.default_tolerance,
+                      params.default_max_iterations, params.tolerance_type),
           EnableBatchLinOp<ConcreteSolver, PolymorphicBase>(
               exec, gko::transpose(system_matrix->get_size()))
     {
@@ -315,13 +330,12 @@ class EnableBatchSolver
         using Identity = matrix::Identity<value_type>;
         using real_type = remove_complex<value_type>;
 
-        if (common_params.generated_prec) {
-            GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(common_params.generated_prec,
+        if (params.generated_preconditioner) {
+            GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner,
                                               this);
-            preconditioner_ = std::move(common_params.generated_prec);
-        } else if (common_params.prec_factory) {
-            preconditioner_ =
-                common_params.prec_factory->generate(system_matrix_);
+            preconditioner_ = std::move(params.generated_preconditioner);
+        } else if (params.preconditioner) {
+            preconditioner_ = params.preconditioner->generate(system_matrix_);
         } else {
             auto id = Identity::create(exec, system_matrix->get_size());
             preconditioner_ = std::move(id);
@@ -341,7 +355,7 @@ class EnableBatchSolver
         if (b->get_common_size()[1] > 1) {
             GKO_NOT_IMPLEMENTED;
         }
-        auto log_data_ = std::make_unique<log::BatchLogData<real_type>>(
+        auto log_data_ = std::make_unique<log::detail::BatchLogData<real_type>>(
             exec, b->get_num_batch_items(), workspace_);
 
         this->solver_apply(b, x, log_data_.get());
@@ -361,9 +375,9 @@ class EnableBatchSolver
         x->add_scaled(alpha, x_clone.get());
     }
 
-    virtual void solver_apply(const MultiVector<ValueType>* b,
-                              MultiVector<ValueType>* x,
-                              log::BatchLogData<real_type>* info) const = 0;
+    virtual void solver_apply(
+        const MultiVector<ValueType>* b, MultiVector<ValueType>* x,
+        log::detail::BatchLogData<real_type>* info) const = 0;
 };
 
 
diff --git a/include/ginkgo/core/stop/batch_stop_enum.hpp b/include/ginkgo/core/stop/batch_stop_enum.hpp
index d960e384d24..3199392cf3e 100644
--- a/include/ginkgo/core/stop/batch_stop_enum.hpp
+++ b/include/ginkgo/core/stop/batch_stop_enum.hpp
@@ -39,6 +39,18 @@ namespace batch {
 namespace stop {
 
 
+/**
+ * This enum provides two types of options for the convergence of an iterative
+ * solver.
+ *
+ * `absolute` tolerance implies that the convergence criteria check is
+ * against the computed residual ($||r|| <= \tau$, where $||r||$ may be implicit
+ * or explicit depending on the solver).
+ *
+ * With the `relative` tolerance type, the solver
+ * convergence criteria checks against the relative residual norm
+ * ($\frac{||r||}{||b||} < \tau$, where $||b||$$ is the L2 norm of the rhs).
+ */
 enum class ToleranceType { absolute, relative };
 
 
diff --git a/omp/solver/batch_bicgstab_kernels.cpp b/omp/solver/batch_bicgstab_kernels.cpp
index 207ae042a4c..822c8820551 100644
--- a/omp/solver/batch_bicgstab_kernels.cpp
+++ b/omp/solver/batch_bicgstab_kernels.cpp
@@ -100,12 +100,11 @@ class KernelCaller {
             // TODO: Align to cache line boundary
             // TODO: Allocate and free once per thread rather than once per
             // work-item.
-            const auto local_space =
-                static_cast<unsigned char*>(malloc(local_size_bytes));
+            auto local_space = array<unsigned char>(exec_, local_size_bytes);
             batch_entry_bicgstab_impl<StopType, PrecondType, LogType,
                                       BatchMatrixType, ValueType>(
-                settings_, logger, precond, mat, b, x, batch_id, local_space);
-            free(local_space);
+                settings_, logger, precond, mat, b, x, batch_id,
+                local_space.get_data());
         }
     }
 
@@ -122,7 +121,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
            const batch::BatchLinOp* const precond,
            const batch::MultiVector<ValueType>* const b,
            batch::MultiVector<ValueType>* const x,
-           batch::log::BatchLogData<remove_complex<ValueType>>& logdata)
+           batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata)
 {
     auto dispatcher = batch::solver::create_dispatcher<ValueType>(
         KernelCaller<ValueType>(exec, settings), settings, mat, precond);
diff --git a/reference/log/batch_logger.hpp b/reference/log/batch_logger.hpp
index e9dadb56ddc..0b1be52e1f4 100644
--- a/reference/log/batch_logger.hpp
+++ b/reference/log/batch_logger.hpp
@@ -44,7 +44,7 @@ namespace batch_log {
 
 
 /**
- * Logs the final residual and iteration count for a batch solver.
+ * Logs the final residual norm and iteration count for a batch solver.
  *
  * @note Supports only a single RHS per batch item.
  */
@@ -66,11 +66,11 @@ class SimpleFinalLogger final {
     {}
 
     /**
-     * Logs the iteration count and residual norm.
+     * Logs the final iteration count and the final residual norm.
      *
      * @param batch_idx  The index of linear system in the batch to log.
-     * @param iter  The current iteration count (0-based).
-     * @param res_norm  Norm of current residual
+     * @param iter  The final iteration count (0-based).
+     * @param res_norm  Norm of final residual norm
      */
     void log_iteration(const size_type batch_idx, const int iter,
                        const real_type res_norm)
diff --git a/reference/solver/batch_bicgstab_kernels.cpp b/reference/solver/batch_bicgstab_kernels.cpp
index 5b5d80794ad..b35b28c2cbf 100644
--- a/reference/solver/batch_bicgstab_kernels.cpp
+++ b/reference/solver/batch_bicgstab_kernels.cpp
@@ -119,7 +119,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
            const batch::BatchLinOp* const precon,
            const batch::MultiVector<ValueType>* const b,
            batch::MultiVector<ValueType>* const x,
-           batch::log::BatchLogData<remove_complex<ValueType>>& log_data)
+           batch::log::detail::BatchLogData<remove_complex<ValueType>>& log_data)
 {
     auto dispatcher = batch::solver::create_dispatcher<ValueType>(
         KernelCaller<ValueType>(exec, settings), settings, mat, precon);
diff --git a/reference/solver/batch_bicgstab_kernels.hpp.inc b/reference/solver/batch_bicgstab_kernels.hpp.inc
index 0bf38890fe2..0a281b34d49 100644
--- a/reference/solver/batch_bicgstab_kernels.hpp.inc
+++ b/reference/solver/batch_bicgstab_kernels.hpp.inc
@@ -274,6 +274,8 @@ inline void batch_entry_bicgstab_impl(
 
     for (iter = 0; iter < settings.max_iterations; iter++) {
         if (stop.check_converged(res_norms_entry.values)) {
+            logger.log_iteration(batch_item_id, iter,
+                                 res_norms_entry.values[0]);
             break;
         }
 
@@ -313,13 +315,12 @@ inline void batch_entry_bicgstab_impl(
                                         res_norms_entry);
 
         if (stop.check_converged(res_norms_entry.values)) {
-            // update x for the systems (rhs) which converge at this point...  x
-            // = x + alpha*p_hat
-            // note bits could change from 0 to 1, not the other way round, so
-            // we can use xor to get info about recent convergence...
-            // const uint32 converged_recent = converged_prev ^ converged;
+            // update x for the systems
+            // x = x + alpha * p_hat
             update_x_middle(gko::batch::to_const(alpha_entry),
                             gko::batch::to_const(p_hat_entry), x_entry);
+            logger.log_iteration(batch_item_id, iter,
+                                 res_norms_entry.values[0]);
             break;
         }
 
diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp
index 839f3c6961d..93d34befe91 100644
--- a/reference/test/solver/batch_bicgstab_kernels.cpp
+++ b/reference/test/solver/batch_bicgstab_kernels.cpp
@@ -64,7 +64,7 @@ class BatchBicgstab : public ::testing::Test {
     using MVec = gko::batch::MultiVector<value_type>;
     using RealMVec = gko::batch::MultiVector<real_type>;
     using Settings = gko::kernels::batch_bicgstab::BicgstabSettings<real_type>;
-    using LogData = gko::batch::log::BatchLogData<real_type>;
+    using LogData = gko::batch::log::detail::BatchLogData<real_type>;
     using LinSys = gko::test::LinearSystem<Mtx>;
 
     BatchBicgstab()
@@ -181,7 +181,7 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsIterations)
     auto res = gko::test::solve_linear_system(
         this->exec, this->solve_lambda, solver_settings, this->linear_system);
 
-    const int* const iter_array = res.log_data->iter_counts.get_const_data();
+    auto iter_array = res.log_data->iter_counts.get_const_data();
     for (size_t i = 0; i < this->num_batch_items; i++) {
         ASSERT_EQ(iter_array[i], ref_iters);
     }
@@ -239,7 +239,7 @@ TYPED_TEST(BatchBicgstab, ApplyLogsResAndIters)
     const int num_rows = 13;
     const size_t num_batch_items = 5;
     const int num_rhs = 1;
-    std::shared_ptr<Logger> logger = Logger::create(this->exec);
+    std::shared_ptr<Logger> logger = Logger::create();
     auto linear_system = gko::test::generate_3pt_stencil_batch_problem<Mtx>(
         this->exec, num_batch_items, num_rows, num_rhs);
     auto solver = gko::share(solver_factory->generate(linear_system.matrix));
diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp
index e29d20cad83..f96e2c0948c 100644
--- a/test/solver/batch_bicgstab_kernels.cpp
+++ b/test/solver/batch_bicgstab_kernels.cpp
@@ -63,7 +63,7 @@ class BatchBicgstab : public CommonTestFixture {
     using MVec = gko::batch::MultiVector<value_type>;
     using RealMVec = gko::batch::MultiVector<real_type>;
     using Settings = gko::kernels::batch_bicgstab::BicgstabSettings<real_type>;
-    using LogData = gko::batch::log::BatchLogData<real_type>;
+    using LogData = gko::batch::log::detail::BatchLogData<real_type>;
     using Logger = gko::batch::log::BatchConvergence<real_type>;
 
     BatchBicgstab() {}
@@ -207,7 +207,7 @@ TEST_F(BatchBicgstab, CanSolveLargeHpdSystem)
             .with_default_tolerance(tol)
             .with_tolerance_type(gko::batch::stop::ToleranceType::absolute)
             .on(exec);
-    std::shared_ptr<Logger> logger = Logger::create(exec);
+    std::shared_ptr<Logger> logger = Logger::create();
     auto linear_system = gko::test::generate_diag_dominant_batch_problem<Mtx>(
         exec, num_batch_items, num_rows, num_rhs, true);
     auto solver = gko::share(solver_factory->generate(linear_system.matrix));