Skip to content

Commit

Permalink
Add batch::Csr to batch::Bicgstab dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Nov 6, 2023
1 parent 87093cc commit 65d4ac9
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 3 deletions.
4 changes: 4 additions & 0 deletions core/solver/batch_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ class batch_solver_dispatch {
mat_)) {
auto mat_item = device::get_batch_struct(batch_mat);
dispatch_on_logger(mat_item, b_item, x_item, log_data);
} else if (auto batch_mat = dynamic_cast<
const batch::matrix::Csr<ValueType, int32>*>(mat_)) {
auto mat_item = device::get_batch_struct(batch_mat);
dispatch_on_logger(mat_item, b_item, x_item, log_data);
} else {
GKO_NOT_SUPPORTED(mat_);
}
Expand Down
1 change: 1 addition & 0 deletions cuda/solver/batch_bicgstab_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ namespace batch_bicgstab {

#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc"
#include "common/cuda_hip/components/uninitialized_array.hpp.inc"
#include "common/cuda_hip/matrix/batch_csr_kernels.hpp.inc"
#include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc"
#include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc"
#include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc"
Expand Down
8 changes: 5 additions & 3 deletions dpcpp/solver/batch_bicgstab_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ namespace batch_bicgstab {


#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"
#include "dpcpp/matrix/batch_csr_kernels.hpp.inc"
#include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
#include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
#include "dpcpp/solver/batch_bicgstab_kernels.hpp.inc"
Expand Down Expand Up @@ -118,9 +119,10 @@ class KernelCaller {
slm_values(sycl::range<1>(shared_size), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
subgroup_size)]] [[intel::kernel_args_restrict]] {
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
subgroup_size)]] [
[intel::kernel_args_restrict]] {
auto batch_id = item_ct1.get_group_linear_id();
const auto mat_global_entry =
gko::batch::matrix::extract_batch_item(mat, batch_id);
Expand Down
1 change: 1 addition & 0 deletions hip/solver/batch_bicgstab_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ namespace batch_bicgstab {

#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc"
#include "common/cuda_hip/components/uninitialized_array.hpp.inc"
#include "common/cuda_hip/matrix/batch_csr_kernels.hpp.inc"
#include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc"
#include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc"
#include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc"
Expand Down
1 change: 1 addition & 0 deletions omp/solver/batch_bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ constexpr int max_num_rhs = 1;


#include "reference/base/batch_multi_vector_kernels.hpp.inc"
#include "reference/matrix/batch_csr_kernels.hpp.inc"
#include "reference/matrix/batch_dense_kernels.hpp.inc"
#include "reference/matrix/batch_ell_kernels.hpp.inc"
#include "reference/solver/batch_bicgstab_kernels.hpp.inc"
Expand Down
1 change: 1 addition & 0 deletions reference/solver/batch_bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ constexpr int max_num_rhs = 1;


#include "reference/base/batch_multi_vector_kernels.hpp.inc"
#include "reference/matrix/batch_csr_kernels.hpp.inc"
#include "reference/matrix/batch_dense_kernels.hpp.inc"
#include "reference/matrix/batch_ell_kernels.hpp.inc"
#include "reference/solver/batch_bicgstab_kernels.hpp.inc"
Expand Down
38 changes: 38 additions & 0 deletions reference/test/solver/batch_bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/log/batch_logger.hpp>
#include <ginkgo/core/matrix/batch_csr.hpp>
#include <ginkgo/core/matrix/batch_dense.hpp>
#include <ginkgo/core/matrix/batch_ell.hpp>

Expand All @@ -61,6 +62,7 @@ class BatchBicgstab : public ::testing::Test {
using solver_type = gko::batch::solver::Bicgstab<value_type>;
using Mtx = gko::batch::matrix::Dense<value_type>;
using EllMtx = gko::batch::matrix::Ell<value_type>;
using CsrMtx = gko::batch::matrix::Csr<value_type>;
using MVec = gko::batch::MultiVector<value_type>;
using RealMVec = gko::batch::MultiVector<real_type>;
using Settings = gko::kernels::batch_bicgstab::settings<real_type>;
Expand Down Expand Up @@ -274,6 +276,42 @@ TYPED_TEST(BatchBicgstab, CanSolveEllSystem)
}


TYPED_TEST(BatchBicgstab, CanSolveCsrSystem)
{
using value_type = typename TestFixture::value_type;
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::CsrMtx;
const real_type tol = 1e-5;
const int max_iters = 1000;
auto solver_factory =
Solver::build()
.with_max_iterations(max_iters)
.with_tolerance(tol)
.with_tolerance_type(gko::batch::stop::tolerance_type::relative)
.on(this->exec);
const int num_rows = 13;
const size_t num_batch_items = 2;
const int num_rhs = 1;
auto stencil_mat =
gko::share(gko::test::generate_3pt_stencil_batch_matrix<Mtx>(
this->exec, num_batch_items, num_rows, (num_rows * 3 - 2)));
auto linear_system =
gko::test::generate_batch_linear_system(stencil_mat, num_rhs);
auto solver = gko::share(solver_factory->generate(linear_system.matrix));

auto res =
gko::test::solve_linear_system(this->exec, linear_system, solver);

GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10);
for (size_t i = 0; i < num_batch_items; i++) {
ASSERT_LE(res.host_res_norm->get_const_values()[i] /
linear_system.host_rhs_norm->get_const_values()[i],
tol * 10);
}
}


TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem)
{
using value_type = typename TestFixture::value_type;
Expand Down

0 comments on commit 65d4ac9

Please sign in to comment.