diff --git a/dpcpp/matrix/batch_csr_kernels.dp.cpp b/dpcpp/matrix/batch_csr_kernels.dp.cpp index c4281d81b1b..2871c90f0c4 100644 --- a/dpcpp/matrix/batch_csr_kernels.dp.cpp +++ b/dpcpp/matrix/batch_csr_kernels.dp.cpp @@ -97,17 +97,18 @@ void simple_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b.values, x_b.values, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + simple_apply_kernel(mat_b, b_b.values, x_b.values, + item_ct1); + }); }); } @@ -145,22 +146,23 @@ void advanced_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto alpha_b = - batch::extract_batch_item(alpha_ub, group_id); - const auto beta_b = - batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, - beta_b.values[0], x_b.values, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto beta_b = + batch::extract_batch_item(beta_ub, group_id); + advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, + beta_b.values[0], x_b.values, + item_ct1); + }); }); } diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 335cc692264..6a4509c8f77 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -91,10 +91,9 @@ class KernelCaller { slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - subgroup_size)]] [ - [intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [[intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); diff --git a/include/ginkgo/core/matrix/batch_csr.hpp b/include/ginkgo/core/matrix/batch_csr.hpp index 63916271263..fad5148e9ed 100644 --- a/include/ginkgo/core/matrix/batch_csr.hpp +++ b/include/ginkgo/core/matrix/batch_csr.hpp @@ -221,8 +221,8 @@ class Csr final * significantly more memory efficient than the non-constant version, * so always prefer this version. */ - const value_type* get_const_values_for_item(size_type batch_id) const - noexcept + const value_type* get_const_values_for_item( + size_type batch_id) const noexcept { GKO_ASSERT(batch_id < this->get_num_batch_items()); GKO_ASSERT(values_.get_num_elems() >= diff --git a/include/ginkgo/ginkgo.hpp b/include/ginkgo/ginkgo.hpp index 62be8aaa394..ab0829625a6 100644 --- a/include/ginkgo/ginkgo.hpp +++ b/include/ginkgo/ginkgo.hpp @@ -81,6 +81,7 @@ #include #include +#include #include #include #include