diff --git a/cuda/solver/batch_cg_kernels.cu b/cuda/solver/batch_cg_kernels.cu
index 541f3f6b936..161e7ee3639 100644
--- a/cuda/solver/batch_cg_kernels.cu
+++ b/cuda/solver/batch_cg_kernels.cu
@@ -206,6 +206,11 @@ public:
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
+            case 5:
+                launch_apply_kernel<StopType, 5, true>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
             default:
                 GKO_NOT_IMPLEMENTED;
             }
diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
index 3c15c94df71..231d89d0f9c 100644
--- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
+++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
@@ -59,10 +59,10 @@ __dpct_inline__ int get_group_size(int value,
 
 
 template <typename ValueType>
-class KernelCaller {
+class kernel_caller {
 public:
-    KernelCaller(std::shared_ptr<const DefaultExecutor> exec,
-                 const settings<remove_complex<ValueType>> settings)
+    kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
+                  const settings<remove_complex<ValueType>> settings)
         : exec_{std::move(exec)}, settings_{settings}
     {}
 
@@ -167,8 +167,7 @@ class KernelCaller {
         int n_shared_total = sconf.n_shared + int(sconf.prec_shared);
 
         // template
-        // launch_apply_kernel<StopType, n_shared_total,
-        // subgroup_size>
+        // launch_apply_kernel<StopType, n_shared_total, subgroup_size>
         if (num_rows <= 32 && n_shared_total == 10) {
             launch_apply_kernel<StopType, 10, 16>(
                 sconf, logger, prec, mat, b.values, x.values, workspace_data,
@@ -256,7 +255,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
            batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
 {
     auto dispatcher = batch::solver::create_dispatcher<ValueType>(
-        KernelCaller<ValueType>(exec, settings), settings, mat, precond);
+        kernel_caller<ValueType>(exec, settings), settings, mat, precond);
     dispatcher.apply(b, x, logdata);
 }
diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc
index e7cbf798b1b..90a5fee0e81 100644
--- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc
+++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc
@@ -19,7 +19,6 @@ __dpct_inline__ void initialize(
     const auto sg_id = sg.get_group_id();
     const auto tid = item_ct1.get_local_linear_id();
     const auto group_size = item_ct1.get_local_range().size();
-    const auto group = item_ct1.get_group();
 
     rho_old = one<ValueType>();
     omega = one<ValueType>();
@@ -296,11 +295,15 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
                p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0],
                item_ct1);
     item_ct1.barrier(sycl::access::fence_space::global_and_local);
 
+    int iter = 0;
     // stopping criterion object
     StopType stop(tol, norms_rhs_sh);
+    if (stop.check_converged(norms_res_sh)) {
+        logger.log_iteration(batch_id, iter, norms_res_sh[0]);
+        return;
+    }
 
-    int iter = 0;
     for (; iter < max_iter; iter++) {
         if (stop.check_converged(norms_res_sh)) {
             logger.log_iteration(batch_id, iter, norms_res_sh[0]);
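The batch_bicgstab_kernels.hpp.inc hunk above hoists `iter` and checks the stopping criterion once before entering the loop, so a batch entry whose initial guess already satisfies the tolerance logs iteration 0 and returns without performing a single matrix apply. The same control flow, reduced to a stand-alone sketch (stand-in `stop_t` and printf logging; not Ginkgo code):

    #include <cstdio>

    // stand-in for the kernel's StopType
    struct stop_t {
        double threshold;
        bool check_converged(const double* res) const
        {
            return *res <= threshold;
        }
    };

    int solve(const stop_t& stop, double* res_norm, int max_iter)
    {
        int iter = 0;
        // converged at the initial guess: log iteration 0, skip the loop
        if (stop.check_converged(res_norm)) {
            std::printf("iter %d, res %g\n", iter, *res_norm);
            return iter;
        }
        for (; iter < max_iter; iter++) {
            // ... one solver sweep that updates *res_norm ...
            if (stop.check_converged(res_norm)) break;
        }
        std::printf("iter %d, res %g\n", iter, *res_norm);
        return iter;
    }
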
diff --git a/dpcpp/solver/batch_cg_kernels.dp.cpp b/dpcpp/solver/batch_cg_kernels.dp.cpp
index 922c4baebda..cbc803dcbdc 100644
--- a/dpcpp/solver/batch_cg_kernels.dp.cpp
+++ b/dpcpp/solver/batch_cg_kernels.dp.cpp
@@ -43,12 +43,185 @@ namespace batch_cg {
 #include "dpcpp/matrix/batch_csr_kernels.hpp.inc"
 #include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
 #include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
+#include "dpcpp/solver/batch_cg_kernels.hpp.inc"
 
 
 template <typename ValueType>
 using settings = gko::kernels::batch_cg::settings<ValueType>;
 
 
+__dpct_inline__ int get_group_size(int value,
+                                   int subgroup_size = config::warp_size)
+{
+    int num_sg = ceildiv(value, subgroup_size);
+    return num_sg * subgroup_size;
+}
+
+
+template <typename ValueType>
+class kernel_caller {
+public:
+    kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
+                  const settings<remove_complex<ValueType>> settings)
+        : exec_{std::move(exec)}, settings_{settings}
+    {}
+
+    template <typename StopType, const int n_shared_total,
+              const int subgroup_size, typename PrecType, typename LogType,
+              typename BatchMatrixType>
+    __dpct_inline__ void launch_apply_kernel(
+        const gko::kernels::batch_cg::storage_config& sconf, LogType& logger,
+        PrecType& prec, const BatchMatrixType mat,
+        const ValueType* const __restrict__ b_values,
+        ValueType* const __restrict__ x_values,
+        ValueType* const __restrict__ workspace, const int& group_size,
+        const int& shared_size) const
+    {
+        auto num_rows = mat.num_rows;
+
+        const dim3 block(group_size);
+        const dim3 grid(mat.num_batch_items);
+
+        auto max_iters = settings_.max_iterations;
+        auto res_tol = settings_.residual_tol;
+
+        exec_->get_queue()->submit([&](sycl::handler& cgh) {
+            sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
+                           sycl::access::target::local>
+                slm_values(sycl::range<1>(shared_size), cgh);
+
+            cgh.parallel_for(
+                sycl_nd_range(grid, block),
+                [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
+                    subgroup_size)]] [[intel::kernel_args_restrict]] {
+                    auto batch_id = item_ct1.get_group_linear_id();
+                    const auto mat_global_entry =
+                        gko::batch::matrix::extract_batch_item(mat, batch_id);
+                    const ValueType* const b_global_entry =
+                        gko::batch::multi_vector::batch_item_ptr(
+                            b_values, 1, num_rows, batch_id);
+                    ValueType* const x_global_entry =
+                        gko::batch::multi_vector::batch_item_ptr(
+                            x_values, 1, num_rows, batch_id);
+                    apply_kernel<n_shared_total, StopType>(
+                        sconf, max_iters, res_tol, logger, prec,
+                        mat_global_entry, b_global_entry, x_global_entry,
+                        num_rows, mat.get_single_item_num_nnz(),
+                        static_cast<ValueType*>(slm_values.get_pointer()),
+                        item_ct1, workspace);
+                });
+        });
+    }
+
+    template <typename StopType, typename PrecType, typename LogType,
+              typename BatchMatrixType>
+    void call_kernel(
+        LogType logger, const BatchMatrixType& mat, PrecType prec,
+        const gko::batch::multi_vector::uniform_batch<const ValueType>& b,
+        const gko::batch::multi_vector::uniform_batch<ValueType>& x) const
+    {
+        using real_type = typename gko::remove_complex<ValueType>;
+        const size_type num_batch_items = mat.num_batch_items;
+        const auto num_rows = mat.num_rows;
+        const auto num_rhs = b.num_rhs;
+        GKO_ASSERT(num_rhs == 1);
+
+        auto device = exec_->get_queue()->get_device();
+        int group_size =
+            device.get_info<sycl::info::device::max_work_group_size>();
+        if (group_size > num_rows) {
+            group_size = get_group_size(num_rows);
+        }
+        group_size = std::min(
+            std::max(group_size, static_cast<int>(2 * config::warp_size)),
+            static_cast<int>(max_group_size));
+
+        // Reserve group_size entries for reduce_over_group, 3 for the
+        // intermediate rho and alpha values, and 2 for the norms.
+        // If the remaining amount is negative, set it to 0.
+        const int static_var_mem =
+            (group_size + 3) * sizeof(ValueType) + 2 * sizeof(real_type);
+        int shmem_per_blk = std::max(
+            static_cast<int>(
+                device.get_info<sycl::info::device::local_mem_size>()) -
+                static_var_mem,
+            0);
+        const int padded_num_rows = num_rows;
+        const size_type prec_size = PrecType::dynamic_work_size(
+            padded_num_rows, mat.get_single_item_num_nnz());
+        const auto sconf =
+            gko::kernels::batch_cg::compute_shared_storage<PrecType,
+                                                           ValueType>(
+                shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(),
+                b.num_rhs);
+        const size_t shared_size = sconf.n_shared * padded_num_rows +
+                                   (sconf.prec_shared ? prec_size : 0);
+        auto workspace = gko::array<ValueType>(
+            exec_,
+            sconf.gmem_stride_bytes * num_batch_items / sizeof(ValueType));
+        GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(ValueType) == 0);
+
+        ValueType* const workspace_data = workspace.get_data();
+        int n_shared_total = sconf.n_shared + int(sconf.prec_shared);
+
+        // template
+        // launch_apply_kernel<StopType, n_shared_total, subgroup_size>
+        if (num_rows <= 32 && n_shared_total == 6)
+            launch_apply_kernel<StopType, 6, 16>(
+                sconf, logger, prec, mat, b.values, x.values, workspace_data,
+                group_size, shared_size);
+        else if (num_rows <= 256 && n_shared_total == 6)
+            launch_apply_kernel<StopType, 6, 32>(
+                sconf, logger, prec, mat, b.values, x.values, workspace_data,
+                group_size, shared_size);
+        else {
+            switch (n_shared_total) {
+            case 0:
+                launch_apply_kernel<StopType, 0, 32>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 1:
+                launch_apply_kernel<StopType, 1, 32>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 2:
+                launch_apply_kernel<StopType, 2, 32>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 3:
+                launch_apply_kernel<StopType, 3, 32>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 4:
+                launch_apply_kernel<StopType, 4, 32>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 5:
+                launch_apply_kernel<StopType, 5, 32>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 6:
+                launch_apply_kernel<StopType, 6, 32>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            default:
+                GKO_NOT_IMPLEMENTED;
+            }
+        }
+    }
+
+private:
+    std::shared_ptr<const DefaultExecutor> exec_;
+    const settings<remove_complex<ValueType>> settings_;
+};
+
+
 template <typename ValueType>
 void apply(std::shared_ptr<const DefaultExecutor> exec,
            const settings<remove_complex<ValueType>>& settings,
@@ -58,7 +231,9 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
            batch::MultiVector<ValueType>* const x,
            batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
 {
-    GKO_NOT_IMPLEMENTED;
+    auto dispatcher = batch::solver::create_dispatcher<ValueType>(
+        kernel_caller<ValueType>(exec, settings), settings, mat, prec);
+    dispatcher.apply(b, x, logdata);
 }
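In the dispatch above, `n_shared_total` counts how many of the six per-system work arrays (the five CG vectors r, z, p, Ap, x plus the preconditioner workspace) fit into SLM, and the switch maps that runtime count onto a compile-time template parameter so the kernel can place each array with `if constexpr` instead of per-iteration branching. A minimal sketch of this runtime-to-compile-time dispatch pattern (hypothetical names; not Ginkgo code):

    #include <cstdio>
    #include <stdexcept>

    // kernel specialized on the number of SLM-resident work arrays;
    // if constexpr then picks SLM or global placement per array
    template <int n_shared_total>
    void run_kernel()
    {
        std::printf("%d work arrays in shared memory\n", n_shared_total);
    }

    // maps the runtime count onto the template parameter, one case per
    // supported value, exactly as the switch in the diff above does
    void dispatch(int n_shared_total)
    {
        switch (n_shared_total) {
        case 0: run_kernel<0>(); break;
        case 1: run_kernel<1>(); break;
        case 2: run_kernel<2>(); break;
        case 3: run_kernel<3>(); break;
        case 4: run_kernel<4>(); break;
        case 5: run_kernel<5>(); break;
        case 6: run_kernel<6>(); break;
        default: throw std::runtime_error("not implemented");
        }
    }
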
diff --git a/dpcpp/solver/batch_cg_kernels.hpp.inc b/dpcpp/solver/batch_cg_kernels.hpp.inc
new file mode 100644
index 00000000000..e85b591d24c
--- /dev/null
+++ b/dpcpp/solver/batch_cg_kernels.hpp.inc
@@ -0,0 +1,252 @@
+// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
+//
+// SPDX-License-Identifier: BSD-3-Clause
+
+template <typename BatchMatrixType, typename ValueType, typename PrecType>
+__dpct_inline__ void initialize(
+    const int num_rows, const BatchMatrixType& mat_global_entry,
+    const ValueType* const __restrict__ b_global_entry,
+    const ValueType* const __restrict__ x_global_entry,
+    ValueType* const __restrict__ x_shared_entry,
+    ValueType* const __restrict__ r_shared_entry, const PrecType& prec_shared,
+    ValueType* const __restrict__ z_shared_entry, ValueType& rho_old,
+    ValueType* const __restrict__ p_shared_entry,
+    gko::remove_complex<ValueType>& rhs_norms, sycl::nd_item<3> item_ct1)
+{
+    auto sg = item_ct1.get_sub_group();
+    const auto sg_id = sg.get_group_id();
+    const auto tid = item_ct1.get_local_linear_id();
+    const auto group_size = item_ct1.get_local_range().size();
+
+    // copy x from global to shared memory
+    // r = b
+    for (int iz = tid; iz < num_rows; iz += group_size) {
+        x_shared_entry[iz] = x_global_entry[iz];
+        r_shared_entry[iz] = b_global_entry[iz];
+    }
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+    // r = b - A*x
+    advanced_apply_kernel(static_cast<ValueType>(-1.0), mat_global_entry,
+                          x_shared_entry, static_cast<ValueType>(1.0),
+                          r_shared_entry, item_ct1);
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+
+    // z = precond * r
+    prec_shared.apply(num_rows, r_shared_entry, z_shared_entry, item_ct1);
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+    // Compute norms of rhs
+    // and rho_old = r' * z
+    if (sg_id == 0) {
+        single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms,
+                                    item_ct1);
+    } else if (sg_id == 1) {
+        single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry,
+                                       z_shared_entry, rho_old, item_ct1);
+    }
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+    // p = z
+    for (int iz = tid; iz < num_rows; iz += group_size) {
+        p_shared_entry[iz] = z_shared_entry[iz];
+    }
+}
+
+
+template <typename ValueType>
+__dpct_inline__ void update_p(
+    const int num_rows, const ValueType& rho_new_shared_entry,
+    const ValueType& rho_old_shared_entry,
+    const ValueType* const __restrict__ z_shared_entry,
+    ValueType* const __restrict__ p_shared_entry, sycl::nd_item<3> item_ct1)
+{
+    const ValueType beta = rho_new_shared_entry / rho_old_shared_entry;
+    for (int li = item_ct1.get_local_linear_id(); li < num_rows;
+         li += item_ct1.get_local_range().size()) {
+        p_shared_entry[li] = z_shared_entry[li] + beta * p_shared_entry[li];
+    }
+}
+
+
+template <typename ValueType>
+__dpct_inline__ void update_x_and_r(
+    const int num_rows, const ValueType rho_old_shared_entry,
+    const ValueType* const __restrict__ p_shared_entry,
+    const ValueType* const __restrict__ Ap_shared_entry,
+    ValueType& alpha_shared_entry,
+    ValueType* const __restrict__ x_shared_entry,
+    ValueType* const __restrict__ r_shared_entry, sycl::nd_item<3> item_ct1)
+{
+    auto sg = item_ct1.get_sub_group();
+    const auto tid = item_ct1.get_local_linear_id();
+    if (sg.get_group_id() == 0) {
+        single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry,
+                                       Ap_shared_entry, alpha_shared_entry,
+                                       item_ct1);
+    }
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+    if (tid == 0) {
+        alpha_shared_entry = rho_old_shared_entry / alpha_shared_entry;
+    }
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+    for (int li = item_ct1.get_local_linear_id(); li < num_rows;
+         li += item_ct1.get_local_range().size()) {
+        x_shared_entry[li] += alpha_shared_entry * p_shared_entry[li];
+        r_shared_entry[li] -= alpha_shared_entry * Ap_shared_entry[li];
+    }
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+}
+
+
+template <const int n_shared_total, typename StopType, typename PrecType,
+          typename LogType, typename BatchMatrixType, typename ValueType>
+__dpct_inline__ void apply_kernel(
+    const gko::kernels::batch_cg::storage_config sconf, const int max_iter,
+    const gko::remove_complex<ValueType> tol, LogType logger,
+    PrecType prec_shared, const BatchMatrixType& mat_global_entry,
+    const ValueType* const __restrict__ b_global_entry,
+    ValueType* const __restrict__ x_global_entry, const size_type nrows,
+    const size_type nnz, ValueType* const __restrict__ slm_values,
+    sycl::nd_item<3> item_ct1,
+    ValueType* const __restrict__ workspace = nullptr)
+{
+    using real_type = typename gko::remove_complex<ValueType>;
+
+    const auto sg = item_ct1.get_sub_group();
+    const int sg_id = sg.get_group_id();
+    const int sg_size = sg.get_local_range().size();
+    const int num_sg = sg.get_group_range().size();
+
+    const auto group = item_ct1.get_group();
+    const auto batch_id = item_ct1.get_group_linear_id();
+
+    // The whole workgroup shares the same values for these scalars; they
+    // live in a separate group-local allocation, not in the slm_values
+    // buffer
+    ValueType* rho_old_sh;
+    ValueType* rho_new_sh;
+    ValueType* alpha_sh;
+    real_type* norms_rhs_sh;
+    real_type* norms_res_sh;
+    using tile_value_t = ValueType[3];
+    tile_value_t& values =
+        *sycl::ext::oneapi::group_local_memory_for_overwrite<tile_value_t>(
+            group);
+    using tile_real_t = real_type[2];
+    tile_real_t& reals =
+        *sycl::ext::oneapi::group_local_memory_for_overwrite<tile_real_t>(
+            group);
+    rho_old_sh = &values[0];
+    rho_new_sh = &values[1];
+    alpha_sh = &values[2];
+    norms_rhs_sh = &reals[0];
+    norms_res_sh = &reals[1];
+    const int gmem_offset =
+        batch_id * sconf.gmem_stride_bytes / sizeof(ValueType);
+    ValueType* r_sh;
+    ValueType* z_sh;
+    ValueType* p_sh;
+    ValueType* Ap_sh;
+    ValueType* x_sh;
+    ValueType* prec_work_sh;
+
+    if constexpr (n_shared_total >= 1) {
+        r_sh = slm_values;
+    } else {
+        r_sh = workspace + gmem_offset;
+    }
+    if constexpr (n_shared_total == 1) {
+        z_sh = workspace + gmem_offset;
+    } else {
+        z_sh = r_sh + sconf.padded_vec_len;
+    }
+    if constexpr (n_shared_total == 2) {
+        p_sh = workspace + gmem_offset;
+    } else {
+        p_sh = z_sh + sconf.padded_vec_len;
+    }
+    if constexpr (n_shared_total == 3) {
+        Ap_sh = workspace + gmem_offset;
+    } else {
+        Ap_sh = p_sh + sconf.padded_vec_len;
+    }
+    if constexpr (n_shared_total == 4) {
+        x_sh = workspace + gmem_offset;
+    } else {
+        x_sh = Ap_sh + sconf.padded_vec_len;
+    }
+    if constexpr (n_shared_total == 5) {
+        prec_work_sh = workspace + gmem_offset;
+    } else {
+        prec_work_sh = x_sh + sconf.padded_vec_len;
+    }
+
+    // generate preconditioner
+    prec_shared.generate(batch_id, mat_global_entry, prec_work_sh, item_ct1);
+
+    // initialization
+    // compute b norms
+    // r = b - A*x
+    // z = precond*r
+    // rho_old = r' * z (' is for hermitian transpose)
+    // p = z
+    initialize(nrows, mat_global_entry, b_global_entry, x_global_entry, x_sh,
+               r_sh, prec_shared, z_sh, rho_old_sh[0], p_sh, norms_rhs_sh[0],
+               item_ct1);
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+    int iter = 0;
+    // stopping criterion object
+    StopType stop(tol, norms_rhs_sh);
+    norms_res_sh[0] = sqrt(abs(rho_old_sh[0]));
+    if (stop.check_converged(norms_res_sh)) {
+        logger.log_iteration(batch_id, iter, norms_res_sh[0]);
+        return;
+    }
+
+    for (; iter < max_iter; iter++) {
+        // Ap = A * p
+        simple_apply_kernel(mat_global_entry, p_sh, Ap_sh, item_ct1);
+        item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+        // alpha = rho_old / (p' * Ap)
+        // x = x + alpha * p
+        // r = r - alpha * Ap
+        update_x_and_r(nrows, rho_old_sh[0], p_sh, Ap_sh, alpha_sh[0], x_sh,
+                       r_sh, item_ct1);
+        item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+
+        // z = precond * r
+        prec_shared.apply(nrows, r_sh, z_sh, item_ct1);
+        item_ct1.barrier(sycl::access::fence_space::global_and_local);
+
+        // rho_new = (r)' * (z)
+        if (sg_id == 0) {
+            single_rhs_compute_conj_dot_sg(nrows, r_sh, z_sh, rho_new_sh[0],
+                                           item_ct1);
+        }
+        item_ct1.barrier(sycl::access::fence_space::global_and_local);
+        if (sg.leader()) {
+            norms_res_sh[0] = sqrt(abs(rho_new_sh[0]));
+        }
+        item_ct1.barrier(sycl::access::fence_space::global_and_local);
+        if (stop.check_converged(norms_res_sh)) {
+            logger.log_iteration(batch_id, iter, norms_res_sh[0]);
+            break;
+        }
+
+        // beta = rho_new / rho_old
+        // p = z + beta * p
+        update_p(nrows, rho_new_sh[0], rho_old_sh[0], z_sh, p_sh, item_ct1);
+        if (sg.leader()) {
+            rho_old_sh[0] = rho_new_sh[0];
+        }
+    }
+
+    logger.log_iteration(batch_id, iter, norms_res_sh[0]);
+
+    // copy x back to global memory
+    copy_kernel(nrows, x_sh, x_global_entry, item_ct1);
+    item_ct1.barrier(sycl::access::fence_space::global_and_local);
+}
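The recurrence implemented by `apply_kernel` above is standard preconditioned CG, including the early exit before the loop and the use of sqrt(|r' * z|) of the preconditioned inner product as the residual estimate. A compact host-side C++ sketch, with a dense matvec and an identity preconditioner standing in for `simple_apply_kernel` and `prec_shared` (illustrative only; not Ginkgo code):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    using vec = std::vector<double>;

    double dot(const vec& a, const vec& b)
    {
        double s = 0.0;
        for (std::size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
        return s;
    }

    // x <- x + c * y
    void axpy(double c, const vec& y, vec& x)
    {
        for (std::size_t i = 0; i < x.size(); ++i) x[i] += c * y[i];
    }

    int pcg(const std::vector<vec>& A, const vec& b, vec& x, int max_iter,
            double tol)
    {
        const auto n = b.size();
        vec r = b, z(n), p(n), Ap(n);
        // r = b - A * x
        for (std::size_t i = 0; i < n; ++i) r[i] -= dot(A[i], x);
        z = r;  // z = M^{-1} r (identity preconditioner here)
        double rho_old = dot(r, z);
        p = z;
        int iter = 0;
        // same early exit as the kernel: sqrt(|rho|) as residual estimate
        if (std::sqrt(std::abs(rho_old)) <= tol) return iter;
        for (; iter < max_iter; ++iter) {
            for (std::size_t i = 0; i < n; ++i) Ap[i] = dot(A[i], p);
            const double alpha = rho_old / dot(p, Ap);  // rho_old/(p'Ap)
            axpy(alpha, p, x);                          // x = x + alpha * p
            axpy(-alpha, Ap, r);                        // r = r - alpha * Ap
            z = r;                                      // z = M^{-1} r
            const double rho_new = dot(r, z);
            if (std::sqrt(std::abs(rho_new)) <= tol) break;
            const double beta = rho_new / rho_old;
            for (std::size_t i = 0; i < n; ++i) {
                p[i] = z[i] + beta * p[i];              // p = z + beta * p
            }
            rho_old = rho_new;
        }
        return iter;
    }
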
diff --git a/test/solver/CMakeLists.txt b/test/solver/CMakeLists.txt
index b66ddd7cebc..b24e063bc6d 100644
--- a/test/solver/CMakeLists.txt
+++ b/test/solver/CMakeLists.txt
@@ -1,5 +1,5 @@
 ginkgo_create_common_test(batch_bicgstab_kernels)
-ginkgo_create_common_test(batch_cg_kernels DISABLE_EXECUTORS dpcpp)
+ginkgo_create_common_test(batch_cg_kernels)
 ginkgo_create_common_test(bicg_kernels)
 ginkgo_create_common_test(bicgstab_kernels)
 ginkgo_create_common_test(cb_gmres_kernels)
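With `DISABLE_EXECUTORS dpcpp` dropped, the common batch CG test now also runs on SYCL devices and exercises the kernels added above through the public solver interface. A hedged sketch of that user-facing path (public batch-solver API names as best understood; empty operands for brevity; illustrative, not canonical):

    #include <ginkgo/ginkgo.hpp>

    int main()
    {
        // SYCL device executor with a host fallback
        auto exec =
            gko::DpcppExecutor::create(0, gko::ReferenceExecutor::create());
        // A: a batch of small sparse systems; b/x: batched multi-vectors.
        // A real run would fill these with the batch entries to solve.
        auto A = gko::share(gko::batch::matrix::Csr<double>::create(exec));
        auto b = gko::batch::MultiVector<double>::create(exec);
        auto x = gko::batch::MultiVector<double>::create(exec);
        auto solver = gko::batch::solver::Cg<double>::build()
                          .with_max_iterations(100)
                          .with_tolerance(1e-10)
                          .on(exec)
                          ->generate(A);
        solver->apply(b, x);  // dispatches into the new dpcpp kernels
    }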