From ff31a16fd006aa6ac00dbac8c43ddac8e1f102fe Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Thu, 26 Oct 2023 23:05:15 +0200 Subject: [PATCH 01/28] Add cuda batch bicgstab kernels Co-authored-by: Aditya Kashi Co-authored-by: Isha Aggarwal --- .../base/batch_multi_vector_kernels.hpp.inc | 83 +++- .../solver/batch_bicgstab_kernels.hpp.inc | 378 ++++++++++++++++++ core/base/batch_struct.hpp | 9 + core/solver/batch_bicgstab_kernels.hpp | 97 +++++ core/test/utils/batch_helpers.hpp | 4 +- cuda/base/exception.cuh | 56 +++ cuda/base/kernel_config.cuh | 65 +++ cuda/solver/batch_bicgstab_kernels.cu | 231 ++++++++++- test/solver/CMakeLists.txt | 2 +- test/solver/batch_bicgstab_kernels.cpp | 76 ++-- 10 files changed, 959 insertions(+), 42 deletions(-) create mode 100644 common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc create mode 100644 cuda/base/exception.cuh create mode 100644 cuda/base/kernel_config.cuh diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc index 9f77598ff5a..779e2ab0e68 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc @@ -47,10 +47,15 @@ __device__ __forceinline__ void scale( } template -__global__ -__launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel( - const gko::batch::multi_vector::uniform_batch alpha, - const gko::batch::multi_vector::uniform_batch x, Mapping map) +__global__ __launch_bounds__( + default_block_size, + sm_oversubscription) void scale_kernel(const gko::batch::multi_vector:: + uniform_batch + alpha, + const gko::batch::multi_vector:: + uniform_batch + x, + Mapping map) { for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items; batch_id += gridDim.x) { @@ -103,6 +108,28 @@ __global__ __launch_bounds__( } +template +__device__ __forceinline__ void single_rhs_compute_dot(Group subgroup, + const int num_rows, + const ValueType* x, + const ValueType* y, + ValueType& result) + +{ + ValueType val = zero(); + for (int r = subgroup.thread_rank(); r < num_rows; r += subgroup.size()) { + val += conj(x[r]) * y[r]; + } + + // subgroup level reduction + val = reduce(subgroup, val, thrust::plus{}); + + if (subgroup.thread_rank() == 0) { + result = val; + } +} + + template __device__ __forceinline__ void gen_one_dot( const gko::batch::multi_vector::batch_item& x, @@ -149,11 +176,11 @@ __device__ __forceinline__ void compute_gen_dot_product( template __global__ -__launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel( - const gko::batch::multi_vector::uniform_batch x, - const gko::batch::multi_vector::uniform_batch y, - const gko::batch::multi_vector::uniform_batch result, - Mapping map) + __launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel( + const gko::batch::multi_vector::uniform_batch x, + const gko::batch::multi_vector::uniform_batch y, + const gko::batch::multi_vector::uniform_batch result, + Mapping map) { for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items; batch_id += gridDim.x) { @@ -165,6 +192,27 @@ __launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_ } +template +__device__ __forceinline__ void single_rhs_compute_norm2( + Group subgroup, const int num_rows, const ValueType* x, + remove_complex& result) +{ + using real_type = typename gko::remove_complex; + real_type val = zero(); + + for (int r = subgroup.thread_rank(); r < num_rows; r += 
subgroup.size()) {
+        val += squared_norm(x[r]);
+    }
+
+    // subgroup level reduction
+    val = reduce(subgroup, val, thrust::plus<remove_complex<ValueType>>{});
+
+    if (subgroup.thread_rank() == 0) {
+        result = sqrt(val);
+    }
+}
+
+
 template <typename ValueType>
 __device__ __forceinline__ void one_norm2(
     const gko::batch::multi_vector::batch_item<const ValueType>& x,
@@ -238,6 +286,17 @@ __global__ __launch_bounds__(
 }
 
 
+template <typename ValueType>
+__device__ __forceinline__ void single_rhs_copy(const int num_rows,
+                                                const ValueType* in,
+                                                ValueType* out)
+{
+    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
+        out[iz] = in[iz];
+    }
+}
+
+
 /**
  * Copies the values of one multi-vector into another.
  *
@@ -260,9 +319,9 @@ __device__ __forceinline__ void copy(
 
 template <typename ValueType>
 __global__
-__launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel(
-    const gko::batch::multi_vector::uniform_batch<const ValueType> src,
-    const gko::batch::multi_vector::uniform_batch<ValueType> dst)
+    __launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel(
+        const gko::batch::multi_vector::uniform_batch<const ValueType> src,
+        const gko::batch::multi_vector::uniform_batch<ValueType> dst)
 {
     for (size_type batch_id = blockIdx.x; batch_id < src.num_batch_items;
          batch_id += gridDim.x) {
diff --git a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
new file mode 100644
index 00000000000..a4a57d99f01
--- /dev/null
+++ b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
@@ -0,0 +1,378 @@
+/*************************************************************
+Copyright (c) 2017-2023, the Ginkgo authors
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
+IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
+TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*************************************************************/
+
+
+template <typename Group, typename BatchMatrixType_entry, typename ValueType>
+__device__ __forceinline__ void initialize(
+    Group subgroup, const int num_rows, const BatchMatrixType_entry& mat_entry,
+    const ValueType* const b_global_entry,
+    const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega,
+    ValueType& alpha, ValueType* const x_shared_entry,
+    ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
+    ValueType* const p_shared_entry, ValueType* const v_shared_entry,
+    typename gko::remove_complex<ValueType>& rhs_norm,
+    typename gko::remove_complex<ValueType>& res_norm)
+{
+    rho_old = one<ValueType>();
+    omega = one<ValueType>();
+    alpha = one<ValueType>();
+
+    // copy x from global to shared memory
+    // r = b
+    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
+        x_shared_entry[iz] = x_global_entry[iz];
+        r_shared_entry[iz] = b_global_entry[iz];
+    }
+    __syncthreads();
+
+    // r = b - A*x
+    advanced_apply(static_cast<ValueType>(-1.0), mat_entry, x_shared_entry,
+                   static_cast<ValueType>(1.0), r_shared_entry);
+    __syncthreads();
+
+    if (threadIdx.x / config::warp_size == 0) {
+        single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry, res_norm);
+    } else if (threadIdx.x / config::warp_size == 1) {
+        // Compute norms of rhs
+        single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, rhs_norm);
+    }
+    __syncthreads();
+
+    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
+        r_hat_shared_entry[iz] = r_shared_entry[iz];
+        p_shared_entry[iz] = zero<ValueType>();
+        v_shared_entry[iz] = zero<ValueType>();
+    }
+}
+
+
+template <typename ValueType>
+__device__ __forceinline__ void update_p(
+    const int num_rows, const ValueType& rho_new, const ValueType& rho_old,
+    const ValueType& alpha, const ValueType& omega,
+    const ValueType* const r_shared_entry,
+    const ValueType* const v_shared_entry, ValueType* const p_shared_entry)
+{
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        const ValueType beta = (rho_new / rho_old) * (alpha / omega);
+        p_shared_entry[r] =
+            r_shared_entry[r] +
+            beta * (p_shared_entry[r] - omega * v_shared_entry[r]);
+    }
+}
+
+
+template <typename Group, typename ValueType>
+__device__ __forceinline__ void compute_alpha(
+    Group subgroup, const int num_rows, const ValueType& rho_new,
+    const ValueType* const r_hat_shared_entry,
+    const ValueType* const v_shared_entry, ValueType& alpha)
+{
+    if (threadIdx.x / config::warp_size == 0) {
+        single_rhs_compute_dot(subgroup, num_rows, r_hat_shared_entry,
+                               v_shared_entry, alpha);
+    }
+    __syncthreads();
+    if (threadIdx.x == 0) {
+        alpha = rho_new / alpha;
+    }
+}
+
+
+template <typename ValueType>
+__device__ __forceinline__ void update_s(const int num_rows,
+                                         const ValueType* const r_shared_entry,
+                                         const ValueType& alpha,
+                                         const ValueType* const v_shared_entry,
+                                         ValueType* const s_shared_entry)
+{
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        s_shared_entry[r] = r_shared_entry[r] - alpha * v_shared_entry[r];
+    }
+}
+
+
+template <typename Group, typename ValueType>
+__device__ __forceinline__ void compute_omega(
+    Group subgroup, const int num_rows, const ValueType* const t_shared_entry,
+    const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
+{
+    if (threadIdx.x / config::warp_size == 0) {
+        single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
+                               s_shared_entry, omega);
+    } else if (threadIdx.x / config::warp_size == 1) {
+        single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
+                               t_shared_entry, temp);
+    }
+
+    __syncthreads();
+    if (threadIdx.x == 0) {
+        omega /= temp;
+    }
+}
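+
+
+// Note: omega = <t, s> / <t, t> is the least-squares choice of the
+// relaxation parameter, i.e. it minimizes the 2-norm of the updated
+// residual r = s - omega * t that update_x_and_r below computes.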
+
+
+template <typename ValueType>
+__device__ __forceinline__ void update_x_and_r(
+    const int num_rows, const ValueType* const p_hat_shared_entry,
+    const ValueType* const s_hat_shared_entry, const ValueType& alpha,
+    const ValueType& omega, const ValueType* const s_shared_entry,
+    const ValueType* const t_shared_entry, ValueType* const x_shared_entry,
+    ValueType* const r_shared_entry)
+{
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r] +
+                            omega * s_hat_shared_entry[r];
+        r_shared_entry[r] = s_shared_entry[r] - omega * t_shared_entry[r];
+    }
+}
+
+
+template <typename ValueType>
+__device__ __forceinline__ void update_x_middle(
+    const int num_rows, const ValueType& alpha,
+    const ValueType* const p_hat_shared_entry, ValueType* const x_shared_entry)
+{
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r];
+    }
+}
+
+
+template <typename StopType, const int n_shared, const bool prec_shared_bool,
+          typename PrecType, typename LogType, typename BatchMatrixType,
+          typename ValueType>
+__global__ void apply_kernel(
+    const gko::kernels::batch_bicgstab::storage_config sconf,
+    const int max_iter, const gko::remove_complex<ValueType> tol,
+    LogType logger, PrecType prec_shared, const BatchMatrixType mat,
+    const ValueType* const __restrict__ b, ValueType* const __restrict__ x,
+    ValueType* const __restrict__ workspace = nullptr)
+{
+    using real_type = typename gko::remove_complex<ValueType>;
+    const auto num_batch_items = mat.num_batch_items;
+    const auto num_rows = mat.num_rows;
+
+    constexpr auto tile_size = config::warp_size;
+    auto thread_block = group::this_thread_block();
+    auto subgroup = group::tiled_partition<tile_size>(thread_block);
+
+    for (int batch_id = blockIdx.x; batch_id < num_batch_items;
+         batch_id += gridDim.x) {
+        const int gmem_offset =
+            batch_id * sconf.gmem_stride_bytes / sizeof(ValueType);
+        extern __shared__ char local_mem_sh[];
+
+        ValueType* p_hat_sh;
+        ValueType* s_hat_sh;
+        ValueType* p_sh;
+        ValueType* s_sh;
+        ValueType* r_sh;
+        ValueType* r_hat_sh;
+        ValueType* v_sh;
+        ValueType* t_sh;
+        ValueType* x_sh;
+        ValueType* prec_work_sh;
+
+        if (n_shared >= 1) {
+            p_hat_sh = reinterpret_cast<ValueType*>(local_mem_sh);
+        } else {
+            p_hat_sh = workspace + gmem_offset;
+        }
+        if (n_shared == 1) {
+            s_hat_sh = workspace + gmem_offset;
+        } else {
+            s_hat_sh = p_hat_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 2) {
+            v_sh = workspace + gmem_offset;
+        } else {
+            v_sh = s_hat_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 3) {
+            t_sh = workspace + gmem_offset;
+        } else {
+            t_sh = v_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 4) {
+            p_sh = workspace + gmem_offset;
+        } else {
+            p_sh = t_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 5) {
+            s_sh = workspace + gmem_offset;
+        } else {
+            s_sh = p_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 6) {
+            r_sh = workspace + gmem_offset;
+        } else {
+            r_sh = s_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 7) {
+            r_hat_sh = workspace + gmem_offset;
+        } else {
+            r_hat_sh = r_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 8) {
+            x_sh = workspace + gmem_offset;
+        } else {
+            x_sh = r_hat_sh + sconf.padded_vec_len;
+        }
+        if (!prec_shared_bool && n_shared == 9) {
+            prec_work_sh = workspace + gmem_offset;
+        } else {
+            prec_work_sh = x_sh + sconf.padded_vec_len;
+        }
+
+        __shared__ uninitialized_array<ValueType, 1> rho_old_sh;
+        __shared__ uninitialized_array<ValueType, 1> rho_new_sh;
+        __shared__ uninitialized_array<ValueType, 1> omega_sh;
+        __shared__ uninitialized_array<ValueType, 1> alpha_sh;
+        __shared__ uninitialized_array<ValueType, 1> temp_sh;
+        __shared__ real_type norms_rhs_sh[1];
+        __shared__ real_type norms_res_sh[1];
+
+        const auto mat_entry =
+            gko::batch::matrix::extract_batch_item(mat, batch_id);
+        const ValueType* const b_entry_ptr =
+            gko::batch::multi_vector::batch_item_ptr(b, 1, num_rows, batch_id);
+        
ValueType* const x_gl_entry_ptr = + gko::batch::multi_vector::batch_item_ptr(x, 1, num_rows, batch_id); + + // generate preconditioner + prec_shared.generate(batch_id, mat_entry, prec_work_sh); + + // initialization + // rho_old = 1, omega = 1, alpha = 1 + // compute b norms + // copy x from global to shared memory + // r = b - A*x + // compute residual norms + // r_hat = r + // p = 0 + // v = 0 + initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr, + rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, + r_hat_sh, p_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0]); + __syncthreads(); + + // stopping criterion object + StopType stop(tol, norms_rhs_sh); + + int iter = 0; + for (; iter < max_iter; iter++) { + if (stop.check_converged(norms_res_sh)) { + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + break; + } + + // rho_new = < r_hat , r > = (r_hat)' * (r) + if (threadIdx.x / config::warp_size == 0) { + single_rhs_compute_dot(subgroup, num_rows, r_hat_sh, r_sh, + rho_new_sh[0]); + } + __syncthreads(); + + // beta = (rho_new / rho_old)*(alpha / omega) + // p = r + beta*(p - omega * v) + update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], + omega_sh[0], r_sh, v_sh, p_sh); + __syncthreads(); + + // p_hat = precond * p + prec_shared.apply(num_rows, p_sh, p_hat_sh); + __syncthreads(); + + // v = A * p_hat + simple_apply(mat_entry, p_hat_sh, v_sh); + __syncthreads(); + + // alpha = rho_new / < r_hat , v> + compute_alpha(subgroup, num_rows, rho_new_sh[0], r_hat_sh, v_sh, + alpha_sh[0] /*, converged*/); + __syncthreads(); + + // s = r - alpha*v + update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh /*, converged*/); + __syncthreads(); + + // an estimate of residual norms + if (threadIdx.x / config::warp_size == 0) { + single_rhs_compute_norm2(subgroup, num_rows, s_sh, + norms_res_sh[0]); + } + __syncthreads(); + + // if (norms_res_sh[0] / norms_rhs_sh[0] < tol) { + if (stop.check_converged(norms_res_sh)) { + update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh); + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + break; + } + + // s_hat = precond * s + prec_shared.apply(num_rows, s_sh, s_hat_sh); + __syncthreads(); + + // t = A * s_hat + simple_apply(mat_entry, s_hat_sh, t_sh); + __syncthreads(); + + // omega = / + compute_omega(subgroup, num_rows, t_sh, s_sh, temp_sh[0], + omega_sh[0] /*, converged*/); + __syncthreads(); + + // x = x + alpha*p_hat + omega *s_hat + // r = s - omega * t + update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], + omega_sh[0], s_sh, t_sh, x_sh, r_sh /*, converged*/); + __syncthreads(); + + if (threadIdx.x / config::warp_size == 0) { + single_rhs_compute_norm2(subgroup, num_rows, r_sh, + norms_res_sh[0]); + } + //__syncthreads(); + + if (threadIdx.x == blockDim.x - 1) { + rho_old_sh[0] = rho_new_sh[0]; + } + __syncthreads(); + } + + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + + // copy x back to global memory + single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr); + __syncthreads(); + } +} diff --git a/core/base/batch_struct.hpp b/core/base/batch_struct.hpp index 975671739eb..041630af66e 100644 --- a/core/base/batch_struct.hpp +++ b/core/base/batch_struct.hpp @@ -78,6 +78,15 @@ struct uniform_batch { }; +template +GKO_ATTRIBUTES GKO_INLINE ValueType* batch_item_ptr( + ValueType* const batch_start, const size_type stride, const int num_rows, + const size_type batch_idx) +{ + return batch_start + batch_idx * stride * num_rows; +} + + } // namespace multi_vector diff --git a/core/solver/batch_bicgstab_kernels.hpp 
b/core/solver/batch_bicgstab_kernels.hpp
index 4689badeebd..ccde3aa6826 100644
--- a/core/solver/batch_bicgstab_kernels.hpp
+++ b/core/solver/batch_bicgstab_kernels.hpp
@@ -92,6 +92,103 @@ inline int local_memory_requirement(const int num_rows, const int num_rhs)
 }
 
 
+struct storage_config {
+    // whether the preconditioner workspace is kept in shared memory
+    bool prec_shared;
+    // total number of shared vectors
+    int n_shared;
+    // number of vectors in global memory
+    int n_global;
+    // global stride (in bytes) from one batch entry to the next
+    int gmem_stride_bytes;
+    // padded vector length
+    int padded_vec_len;
+};
+
+
+template <int align_bytes>
+void set_gmem_stride_bytes(storage_config& sconf,
+                           const int multi_vector_size_bytes,
+                           const int prec_storage_bytes)
+{
+    int gmem_stride = sconf.n_global * multi_vector_size_bytes;
+    if (!sconf.prec_shared) {
+        gmem_stride += prec_storage_bytes;
+    }
+    // align global memory chunks
+    sconf.gmem_stride_bytes =
+        gmem_stride > 0 ? ((gmem_stride - 1) / align_bytes + 1) * align_bytes
+                        : 0;
+}
+
+
+/**
+ * Calculates the amount of in-solver storage needed by batch-Bicgstab and
+ * the split between shared and global memory.
+ *
+ * The calculation includes multivectors for
+ * - r
+ * - r_hat
+ * - p
+ * - p_hat
+ * - v
+ * - s
+ * - s_hat
+ * - t
+ * - x
+ *
+ * In addition, small arrays are needed for
+ * - rho_old
+ * - rho_new
+ * - omega
+ * - alpha
+ * - temp
+ * - rhs_norms
+ * - res_norms
+ *
+ * @param shared_mem_per_blk  The amount of shared memory per block available
+ *   for keeping intermediate vectors. If keeping the matrix in L1 cache etc.
+ *   should be prioritized instead, the cache configuration must be updated
+ *   separately and the needed space subtracted before calling this function.
+ * @param num_rows  Size of the matrix.
+ * @param num_nz  Number of nonzeros in the matrix.
+ * @param num_rhs  Number of right-hand sides in the vectors.
+ * @return  A struct containing allocation information specific to Bicgstab.
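+ *
+ * As an illustration of the split: for num_rows = 100 rows in double
+ * precision with one right-hand side, each multivector occupies
+ * 100 * 8 = 800 bytes, so a budget of shared_mem_per_blk = 4096 bytes keeps
+ * 5 of the 9 work vectors in shared memory, while the other 4 (plus any
+ * preconditioner storage) are placed in the aligned global-memory workspace.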
+ */ +template +storage_config compute_shared_storage(const int shared_mem_per_blk, + const int num_rows, const int num_nz, + const int num_rhs) +{ + using real_type = remove_complex; + const int vec_size = num_rows * num_rhs * sizeof(ValueType); + const int num_main_vecs = 9; + const int prec_storage = + Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType); + int rem_shared = shared_mem_per_blk; + storage_config sconf{false, 0, num_main_vecs, 0, num_rows}; + if (rem_shared <= 0) { + set_gmem_stride_bytes(sconf, vec_size, prec_storage); + return sconf; + } + const int initial_vecs_available = rem_shared / vec_size; + const int num_vecs_shared = min(initial_vecs_available, num_main_vecs); + sconf.n_shared += num_vecs_shared; + sconf.n_global -= num_vecs_shared; + if (sconf.n_global > 0) { + set_gmem_stride_bytes(sconf, vec_size, prec_storage); + return sconf; + } + rem_shared -= num_vecs_shared * vec_size; + if (rem_shared >= prec_storage && prec_storage > 0) { + sconf.prec_shared = true; + rem_shared -= prec_storage; + } + set_gmem_stride_bytes(sconf, vec_size, prec_storage); + return sconf; +} + + } // namespace batch_bicgstab diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp index 77c2d397889..0a6702ff42f 100644 --- a/core/test/utils/batch_helpers.hpp +++ b/core/test/utils/batch_helpers.hpp @@ -250,7 +250,7 @@ LinearSystem generate_batch_linear_system( // A * x_{exact} = b sys.matrix->apply(sys.exact_sol, sys.rhs); const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); - sys.rhs_norm = real_vec::create(exec, norm_dim); + sys.rhs_norm = real_vec::create(exec->get_master(), norm_dim); sys.rhs->compute_norm2(sys.rhs_norm.get()); return sys; } @@ -273,7 +273,7 @@ compute_residual_norms( const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); auto residual_vec = b->clone(); - auto res_norms = real_vec::create(exec, norm_dim); + auto res_norms = real_vec::create(exec->get_master(), norm_dim); auto alpha = gko::batch::initialize(num_batch_items, {-1.0}, exec); auto beta = gko::batch::initialize(num_batch_items, {1.0}, exec); diff --git a/cuda/base/exception.cuh b/cuda/base/exception.cuh new file mode 100644 index 00000000000..51dfb63bf72 --- /dev/null +++ b/cuda/base/exception.cuh @@ -0,0 +1,56 @@ +/************************************************************* +Copyright (c) 2017-2023, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_CUDA_BASE_EXCEPTION_CUH_ +#define GKO_CUDA_BASE_EXCEPTION_CUH_ + + +#include + + +namespace gko { + + +#define GKO_CUDA_LAST_IF_ERROR_THROW \ + cudaError_t err = cudaGetLastError(); \ + if (err != cudaSuccess) { \ + printf(" Kernel error: %s\n", cudaGetErrorString(err)); \ + throw gko::CudaError(__FILE__, __LINE__, __func__, err); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + +} // namespace gko + +#endif // GKO_CUDA_BASE_EXCEPTION_CUH_ diff --git a/cuda/base/kernel_config.cuh b/cuda/base/kernel_config.cuh new file mode 100644 index 00000000000..6280753bcda --- /dev/null +++ b/cuda/base/kernel_config.cuh @@ -0,0 +1,65 @@ +/************************************************************* +Copyright (c) 2017-2023, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_CUDA_BASE_KERNEL_CONFIG_CUH_ +#define GKO_CUDA_BASE_KERNEL_CONFIG_CUH_ + + +#include "cuda/base/math.hpp" + + +namespace gko { +namespace kernels { +namespace cuda { + + +/** + * Set shared memory bank configuration. + * + * \tparam ValueType The scalar type used for computations. 
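+ *
+ * Four-byte scalars (e.g. float) request the four-byte bank configuration,
+ * while types whose size is a multiple of eight bytes (double and the
+ * complex types) request eight-byte banks, so that successive shared-memory
+ * entries of the work vectors map to distinct banks.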
+ */ +template +inline void configure_shared_memory_banks() +{ + if (sizeof(ValueType) == 4) { + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeFourByte); + } else if (sizeof(ValueType) % 8 == 0) { + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte); + } +} + + +} // namespace cuda +} // namespace kernels +} // namespace gko + +#endif // GKO_CUDA_BASE_KERNEL_CONFIG_CUH_ diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index ee7d0948b99..db92543fd74 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -33,21 +33,40 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/solver/batch_bicgstab_kernels.hpp" +#include +#include + + #include #include +#include "core/base/batch_struct.hpp" +#include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" +#include "cuda/base/batch_struct.hpp" #include "cuda/base/config.hpp" +#include "cuda/base/exception.cuh" +#include "cuda/base/kernel_config.cuh" +#include "cuda/base/thrust.cuh" #include "cuda/base/types.hpp" #include "cuda/components/cooperative_groups.cuh" +#include "cuda/components/reduction.cuh" #include "cuda/components/thread_ids.cuh" +#include "cuda/components/uninitialized_array.hpp" #include "cuda/matrix/batch_struct.hpp" namespace gko { namespace kernels { namespace cuda { + + +// NOTE: this default block size is not used for the main solver kernel. +constexpr int default_block_size = 256; +constexpr int sm_oversubscription = 4; + + /** * @brief The batch Bicgstab solver namespace. * @@ -56,19 +75,227 @@ namespace cuda { namespace batch_bicgstab { +#include "common/cuda_hip/components/uninitialized_array.hpp.inc" + +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" +#include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" +#include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" +#include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc" + + +template +int get_num_threads_per_block(std::shared_ptr exec, + const int num_rows) +{ + int nwarps = num_rows / 4; + if (nwarps < 2) { + nwarps = 2; + } + const int min_block_size = 2 * config::warp_size; + const int device_max_threads = + ((std::max(num_rows, min_block_size)) / config::warp_size) * + config::warp_size; + cudaFuncAttributes funcattr; + cudaFuncGetAttributes(&funcattr, + apply_kernel); + const int num_regs_used = funcattr.numRegs; + int max_regs_blk = 0; + cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock, + exec->get_device_id()); + const int max_threads_regs = + ((max_regs_blk / + static_cast((static_cast(num_regs_used)))) / + config::warp_size) * + config::warp_size; + int max_threads = std::min(max_threads_regs, device_max_threads); + max_threads = max_threads <= 1024 ? 
max_threads : 1024;
+    return std::min(nwarps * static_cast<int>(config::warp_size), max_threads);
+}
+
+
+template <typename StopType, typename PrecType, typename LogType,
+          typename BatchMatrixType, typename ValueType>
+int get_max_dynamic_shared_memory(std::shared_ptr<const DefaultExecutor> exec,
+                                  const size_type required_cache_storage)
+{
+    int shmem_per_sm = 0;
+    cudaDeviceGetAttribute(&shmem_per_sm,
+                           cudaDevAttrMaxSharedMemoryPerMultiprocessor,
+                           exec->get_device_id());
+    GKO_ASSERT_NO_CUDA_ERRORS(cudaFuncSetAttribute(
+        apply_kernel<StopType, 9, true, PrecType, LogType, BatchMatrixType,
+                     ValueType>,
+        cudaFuncAttributePreferredSharedMemoryCarveout, 99 /*%*/));
+    cudaFuncAttributes funcattr;
+    cudaFuncGetAttributes(&funcattr,
+                          apply_kernel<StopType, 9, true, PrecType, LogType,
+                                       BatchMatrixType, ValueType>);
+    return funcattr.maxDynamicSharedSizeBytes;
+}
+
+
 template <typename RealType>
 using settings = gko::kernels::batch_bicgstab::settings<RealType>;
 
 
+template <typename CuValueType>
+class KernelCaller {
+public:
+    using value_type = CuValueType;
+
+    KernelCaller(std::shared_ptr<const DefaultExecutor> exec,
+                 const settings<remove_complex<value_type>> settings)
+        : exec_{exec}, settings_{settings}
+    {}
+
+    template <typename StopType, const int n_shared, const bool prec_shared,
+              typename PrecType, typename LogType, typename BatchMatrixType>
+    void launch_apply_kernel(
+        const gko::kernels::batch_bicgstab::storage_config& sconf,
+        LogType& logger, PrecType& prec, const BatchMatrixType& mat,
+        const value_type* const __restrict__ b_values,
+        value_type* const __restrict__ x_values,
+        value_type* const __restrict__ workspace_data, const int& block_size,
+        const size_t& shared_size) const
+    {
+        apply_kernel<StopType, n_shared, prec_shared>
+            <<<mat.num_batch_items, block_size, shared_size,
+               exec_->get_stream()>>>(sconf, settings_.max_iterations,
+                                      settings_.residual_tol, logger, prec,
+                                      mat, b_values, x_values, workspace_data);
+    }
+
+
+    template <typename BatchMatrixType, typename PrecType, typename StopType,
+              typename LogType>
+    void call_kernel(
+        LogType logger, const BatchMatrixType& mat, PrecType prec,
+        const gko::batch::multi_vector::uniform_batch<const value_type>& b,
+        const gko::batch::multi_vector::uniform_batch<value_type>& x) const
+    {
+        using real_type = gko::remove_complex<value_type>;
+        const size_type num_batch_items = mat.num_batch_items;
+        constexpr int align_multiple = 2;
+        const int shared_gap =
+            ((mat.num_rows + align_multiple - 1) / align_multiple) *
+            align_multiple;
+        gko::kernels::cuda::configure_shared_memory_banks<value_type>();
+        const int shmem_per_blk =
+            get_max_dynamic_shared_memory<StopType, PrecType, LogType,
+                                          BatchMatrixType, value_type>(exec_,
+                                                                       0);
+        const int block_size =
+            get_num_threads_per_block<StopType, PrecType, LogType,
+                                      BatchMatrixType, value_type>(
+                exec_, mat.num_rows);
+        assert(block_size >= 2 * config::warp_size);
+
+        const size_t prec_size =
+            PrecType::dynamic_work_size(shared_gap,
+                                        mat.get_single_item_num_nnz()) *
+            sizeof(value_type);
+        const auto sconf =
+            gko::kernels::batch_bicgstab::compute_shared_storage<PrecType,
+                                                                 value_type>(
+                shmem_per_blk, shared_gap, mat.get_single_item_num_nnz(),
+                b.num_rhs);
+        const size_t shared_size =
+            sconf.n_shared * shared_gap * sizeof(value_type) +
+            (sconf.prec_shared ? prec_size : 0);
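+
+        // Illustrative sizing (assuming double precision and a case where
+        // sconf keeps 5 of the 9 work vectors in shared memory): for
+        // mat.num_rows = 100, the padded length shared_gap is 100, so the
+        // launch below requests 5 * 100 * sizeof(double) = 4000 bytes of
+        // dynamic shared memory; the remaining vectors use the global
+        // workspace allocated next.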
+        auto workspace = gko::array<value_type>(
+            exec_,
+            sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type));
+        assert(sconf.gmem_stride_bytes % sizeof(value_type) == 0);
+
+        value_type* const workspace_data = workspace.get_data();
+
+        // Template parameters launch_apply_kernel<StopType, n_shared,
+        // prec_shared>
+        if (sconf.prec_shared)
+            launch_apply_kernel<StopType, 9, true>(
+                sconf, logger, prec, mat, b.values, x.values, workspace_data,
+                block_size, shared_size);
+        else {
+            switch (sconf.n_shared) {
+            case 0:
+                launch_apply_kernel<StopType, 0, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 1:
+                launch_apply_kernel<StopType, 1, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 2:
+                launch_apply_kernel<StopType, 2, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 3:
+                launch_apply_kernel<StopType, 3, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 4:
+                launch_apply_kernel<StopType, 4, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 5:
+                launch_apply_kernel<StopType, 5, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 6:
+                launch_apply_kernel<StopType, 6, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 7:
+                launch_apply_kernel<StopType, 7, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 8:
+                launch_apply_kernel<StopType, 8, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            case 9:
+                launch_apply_kernel<StopType, 9, false>(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, block_size, shared_size);
+                break;
+            }
+        }
+
+        GKO_CUDA_LAST_IF_ERROR_THROW;
+    }
+
+private:
+    std::shared_ptr<const DefaultExecutor> exec_;
+    const settings<remove_complex<value_type>> settings_;
+};
+
+
 template <typename ValueType>
 void apply(std::shared_ptr<const DefaultExecutor> exec,
            const settings<remove_complex<ValueType>>& settings,
-           const batch::BatchLinOp* const a,
+           const batch::BatchLinOp* const mat,
            const batch::BatchLinOp* const precon,
            const batch::MultiVector<ValueType>* const b,
            batch::MultiVector<ValueType>* const x,
            batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
-    GKO_NOT_IMPLEMENTED;
+{
+    using cu_value_type = cuda_type<ValueType>;
+    auto dispatcher = batch::solver::create_dispatcher<ValueType>(
+        KernelCaller<cu_value_type>(exec, settings), settings, mat, precon);
+    dispatcher.apply(b, x, logdata);
+}
 
 
 GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);
diff --git a/test/solver/CMakeLists.txt b/test/solver/CMakeLists.txt
index 296a55b6271..28a217a79fc 100644
--- a/test/solver/CMakeLists.txt
+++ b/test/solver/CMakeLists.txt
@@ -1,4 +1,4 @@
-ginkgo_create_common_test(batch_bicgstab_kernels DISABLE_EXECUTORS dpcpp cuda hip)
+ginkgo_create_common_test(batch_bicgstab_kernels DISABLE_EXECUTORS dpcpp hip)
 ginkgo_create_common_test(bicg_kernels)
 ginkgo_create_common_test(bicgstab_kernels)
 ginkgo_create_common_test(cb_gmres_kernels)
diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp
index adb68d92314..124dd27640c 100644
--- a/test/solver/batch_bicgstab_kernels.cpp
+++ b/test/solver/batch_bicgstab_kernels.cpp
@@ -171,7 +171,7 @@ TEST_F(BatchBicgstab, StencilSystemLoggerLogsIterations)
 
 TEST_F(BatchBicgstab, CanSolve3ptStencilSystem)
 {
-    const int num_batch_items = 12;
+    const int num_batch_items = 8;
     const int num_rows = 100;
     const int num_rhs = 1;
     const real_type tol = 1e-5;
@@ -185,35 +185,59 @@ TEST_F(BatchBicgstab, CanSolve3ptStencilSystem)
 
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, 
tol * 10); for (size_t i = 0; i < num_batch_items; i++) { - auto comp_res_norm = - exec->copy_val_to_host(res.res_norm->get_const_values() + i) / - exec->copy_val_to_host(linear_system.rhs_norm->get_const_values() + - i); + auto comp_res_norm = res.res_norm->get_const_values()[i] / + linear_system.rhs_norm->get_const_values()[i]; ASSERT_LE(comp_res_norm, tol); } } -TEST_F(BatchBicgstab, CanSolveLargeHpdSystem) +TEST_F(BatchBicgstab, CanSolveLargeBatchSizeHpdSystem) { - const int num_batch_items = 3; + const int num_batch_items = 100; + const int num_rows = 102; + const int num_rhs = 1; + const real_type tol = 1e-5; + const int max_iters = num_rows; + std::shared_ptr logger = Logger::create(); + auto mat = gko::share(gko::test::generate_diag_dominant_batch_matrix( + exec, num_batch_items, num_rows, true)); + auto linear_system = setup_linsys_and_solver(mat, num_rhs, tol, max_iters); + auto solver = gko::share(solver_factory->generate(linear_system.matrix)); + solver->add_logger(logger); + + auto res = gko::test::solve_linear_system(exec, linear_system, solver); + + solver->remove_logger(logger); + auto iter_counts = gko::make_temporary_clone(exec->get_master(), + &logger->get_num_iterations()); + auto res_norm = gko::make_temporary_clone(exec->get_master(), + &logger->get_residual_norm()); + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); + for (size_t i = 0; i < num_batch_items; i++) { + auto comp_res_norm = res.res_norm->get_const_values()[i] / + linear_system.rhs_norm->get_const_values()[i]; + ASSERT_LE(iter_counts->get_const_data()[i], max_iters); + EXPECT_LE(res_norm->get_const_data()[i] / + linear_system.rhs_norm->get_const_values()[i], + tol); + EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); + ASSERT_LE(comp_res_norm, tol); + } +} + + +TEST_F(BatchBicgstab, CanSolveLargeMatrixSizeHpdSystem) +{ + const int num_batch_items = 12; const int num_rows = 1025; const int num_rhs = 1; const real_type tol = 1e-5; - const int max_iters = 2000; - const real_type comp_tol = tol * 100; - auto solver_factory = - solver_type::build() - .with_max_iterations(max_iters) - .with_tolerance(tol) - .with_tolerance_type(gko::batch::stop::tolerance_type::absolute) - .on(exec); + const int max_iters = num_rows; std::shared_ptr logger = Logger::create(); - auto diag_dom_mat = - gko::share(gko::test::generate_diag_dominant_batch_matrix( - exec, num_batch_items, num_rows, true)); - auto linear_system = - gko::test::generate_batch_linear_system(diag_dom_mat, num_rhs); + auto mat = gko::share(gko::test::generate_diag_dominant_batch_matrix( + exec, num_batch_items, num_rows, true)); + auto linear_system = setup_linsys_and_solver(mat, num_rhs, tol, max_iters); auto solver = gko::share(solver_factory->generate(linear_system.matrix)); solver->add_logger(logger); @@ -224,13 +248,15 @@ TEST_F(BatchBicgstab, CanSolveLargeHpdSystem) &logger->get_num_iterations()); auto res_norm = gko::make_temporary_clone(exec->get_master(), &logger->get_residual_norm()); - GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, comp_tol); + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); for (size_t i = 0; i < num_batch_items; i++) { - auto comp_res_norm = - exec->copy_val_to_host(res.res_norm->get_const_values() + i); + auto comp_res_norm = res.res_norm->get_const_values()[i] / + linear_system.rhs_norm->get_const_values()[i]; ASSERT_LE(iter_counts->get_const_data()[i], max_iters); - EXPECT_LE(res_norm->get_const_data()[i], comp_tol); + EXPECT_LE(res_norm->get_const_data()[i] / + 
linear_system.rhs_norm->get_const_values()[i], + tol); EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); - ASSERT_LE(comp_res_norm, comp_tol); + ASSERT_LE(comp_res_norm, tol); } } From b357c6b2be4f91ca84fc45d31d18a5b64f4539cc Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Fri, 27 Oct 2023 00:06:11 +0200 Subject: [PATCH 02/28] Add hip bicgstab solver kernels Co-authored-by: Aditya Kashi Co-authored-by: Isha Aggarwal --- cuda/base/executor.cpp | 3 + cuda/solver/batch_bicgstab_kernels.cu | 6 +- hip/base/exception.hip.hpp | 56 +++++++ hip/base/executor.hip.cpp | 3 + hip/solver/batch_bicgstab_kernels.hip.cpp | 192 +++++++++++++++++++++- test/solver/CMakeLists.txt | 2 +- 6 files changed, 256 insertions(+), 6 deletions(-) create mode 100644 hip/base/exception.hip.hpp diff --git a/cuda/base/executor.cpp b/cuda/base/executor.cpp index f296fb9da86..01880127641 100644 --- a/cuda/base/executor.cpp +++ b/cuda/base/executor.cpp @@ -258,6 +258,9 @@ void CudaExecutor::set_gpu_property() kernels::cuda::config::warp_size; this->get_exec_info().max_subgroup_size = kernels::cuda::config::warp_size; + GKO_ASSERT_NO_CUDA_ERRORS(cudaDeviceGetAttribute( + &this->get_exec_info().max_shared_memory_per_workgroup, + cudaDevAttrMaxSharedMemoryPerBlock, this->get_device_id())); } } diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index db92543fd74..07e16535631 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -85,7 +85,7 @@ namespace batch_bicgstab { template -int get_num_threads_per_block(std::shared_ptr exec, +int get_num_threads_per_block(std::shared_ptr exec, const int num_rows) { int nwarps = num_rows / 4; @@ -117,7 +117,7 @@ int get_num_threads_per_block(std::shared_ptr exec, template -int get_max_dynamic_shared_memory(std::shared_ptr exec, +int get_max_dynamic_shared_memory(std::shared_ptr exec, const size_type required_cache_storage) { int shmem_per_sm = 0; @@ -178,7 +178,7 @@ public: { using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; - constexpr int align_multiple = 2; + constexpr int align_multiple = 8; const int shared_gap = ((mat.num_rows + align_multiple - 1) / align_multiple) * align_multiple; diff --git a/hip/base/exception.hip.hpp b/hip/base/exception.hip.hpp new file mode 100644 index 00000000000..7c3b3b2e12e --- /dev/null +++ b/hip/base/exception.hip.hpp @@ -0,0 +1,56 @@ +/************************************************************* +Copyright (c) 2017-2023, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_HIP_BASE_EXCEPTION_HIP_HPP_ +#define GKO_HIP_BASE_EXCEPTION_HIP_HPP_ + + +#include + + +namespace gko { + + +#define GKO_HIP_LAST_IF_ERROR_THROW \ + hipError_t err = hipGetLastError(); \ + if (err != hipSuccess) { \ + printf(" Hip kernel error: %s\n", hipGetErrorString(err)); \ + throw gko::HipError(__FILE__, __LINE__, __func__, err); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + +} // namespace gko + +#endif // GKO_HIP_BASE_EXCEPTION_HIP_HPP_ diff --git a/hip/base/executor.hip.cpp b/hip/base/executor.hip.cpp index 8d175c0e424..489e9b28ff9 100644 --- a/hip/base/executor.hip.cpp +++ b/hip/base/executor.hip.cpp @@ -262,6 +262,9 @@ void HipExecutor::set_gpu_property() #endif // GINKGO_HIP_PLATFORM_NVCC this->get_exec_info().max_subgroup_size = kernels::hip::config::warp_size; + GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute( + &this->get_exec_info().max_shared_memory_per_workgroup, + hipDeviceAttributeMaxSharedMemoryPerBlock, this->get_device_id())); } } diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 4ef8cd36c1b..b9fe8b0c9c3 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -34,21 +34,38 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include +#include #include #include +#include "core/base/batch_struct.hpp" +#include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" #include "hip/base/batch_struct.hip.hpp" #include "hip/base/config.hip.hpp" +#include "hip/base/exception.hip.hpp" +#include "hip/base/math.hip.hpp" +#include "hip/base/thrust.hip.hpp" +#include "hip/base/types.hip.hpp" +#include "hip/components/cooperative_groups.hip.hpp" +#include "hip/components/reduction.hip.hpp" +#include "hip/components/thread_ids.hip.hpp" +#include "hip/components/uninitialized_array.hip.hpp" #include "hip/matrix/batch_struct.hip.hpp" namespace gko { namespace kernels { namespace hip { + + +constexpr int default_block_size = 256; +constexpr int sm_oversubscription = 4; + /** * @brief The batch Bicgstab solver namespace. 
* @@ -57,19 +74,190 @@ namespace hip { namespace batch_bicgstab { +#include "common/cuda_hip/components/uninitialized_array.hpp.inc" + +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" +#include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" +#include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" +#include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc" + + +template +int get_num_threads_per_block(std::shared_ptr exec, + const int num_rows) +{ + int nwarps = num_rows / 4; + if (nwarps < 2) { + nwarps = 2; + } + const int min_block_size = 2 * config::warp_size; + const int device_max_threads = + ((std::max(num_rows, min_block_size)) / config::warp_size) * + config::warp_size; + const int num_regs_used_per_thread = 64; + int max_regs_blk = 0; + hipDeviceGetAttribute(&max_regs_blk, hipDeviceAttributeMaxRegistersPerBlock, + exec->get_device_id()); + const int max_threads_regs = (max_regs_blk / num_regs_used_per_thread); + const int max_threads = std::min(max_threads_regs, device_max_threads); + return std::min(nwarps * static_cast(config::warp_size), max_threads); +} + + template using settings = gko::kernels::batch_bicgstab::settings; +template +class KernelCaller { +public: + using value_type = HipValueType; + + KernelCaller(std::shared_ptr exec, + const settings> settings) + : exec_{exec}, settings_{settings} + {} + + template + void launch_apply_kernel( + const gko::kernels::batch_bicgstab::storage_config& sconf, + LogType& logger, PrecType& prec, const BatchMatrixType& mat, + const value_type* const __restrict__ b_values, + value_type* const __restrict__ x_values, + value_type* const __restrict__ workspace_data, const int& block_size, + const size_t& shared_size) const + { + apply_kernel + <<get_stream()>>>(sconf, settings_.max_iterations, + settings_.residual_tol, logger, prec, mat, + b_values, x_values, workspace_data); + } + + + template + void call_kernel( + LogType logger, const BatchMatrixType& mat, PrecType prec, + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const + { + using real_type = gko::remove_complex; + const size_type num_batch_items = mat.num_batch_items; + constexpr int align_multiple = 8; + const int shared_gap = + ((mat.num_rows + align_multiple - 1) / align_multiple) * + align_multiple; + const int shmem_per_blk = exec_->get_max_shared_memory_per_block(); + const int block_size = + get_num_threads_per_block(exec_, mat.num_rows); + assert(block_size >= 2 * config::warp_size); + + const size_t prec_size = + PrecType::dynamic_work_size(shared_gap, + mat.get_single_item_num_nnz()) * + sizeof(value_type); + const auto sconf = + gko::kernels::batch_bicgstab::compute_shared_storage( + shmem_per_blk, shared_gap, mat.get_single_item_num_nnz(), + b.num_rhs); + const size_t shared_size = + sconf.n_shared * shared_gap * sizeof(value_type) + + (sconf.prec_shared ? 
prec_size : 0); + auto workspace = gko::array( + exec_, + sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); + assert(sconf.gmem_stride_bytes % sizeof(value_type) == 0); + + value_type* const workspace_data = workspace.get_data(); + + // Template parameters launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, workspace_data, + block_size, shared_size); + else { + switch (sconf.n_shared) { + case 0: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 1: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 2: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 3: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 4: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 5: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 6: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 7: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 8: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 9: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + } + } + + GKO_HIP_LAST_IF_ERROR_THROW; + } + +private: + std::shared_ptr exec_; + const settings> settings_; +}; + + template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* const a, + const batch::BatchLinOp* const mat, const batch::BatchLinOp* const precon, const batch::MultiVector* const b, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) - GKO_NOT_IMPLEMENTED; +{ + using hip_value_type = hip_type; + auto dispatcher = batch::solver::create_dispatcher( + KernelCaller(exec, settings), settings, mat, precon); + dispatcher.apply(b, x, logdata); +} GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/test/solver/CMakeLists.txt b/test/solver/CMakeLists.txt index 28a217a79fc..de3430393ae 100644 --- a/test/solver/CMakeLists.txt +++ b/test/solver/CMakeLists.txt @@ -1,4 +1,4 @@ -ginkgo_create_common_test(batch_bicgstab_kernels DISABLE_EXECUTORS dpcpp hip) +ginkgo_create_common_test(batch_bicgstab_kernels DISABLE_EXECUTORS dpcpp) ginkgo_create_common_test(bicg_kernels) ginkgo_create_common_test(bicgstab_kernels) ginkgo_create_common_test(cb_gmres_kernels) From cf9839a284eaa075440124afdba936cf57d8af50 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Fri, 27 Oct 2023 13:11:30 +0200 Subject: [PATCH 03/28] Add dpcpp kernels Co-authored-by: Phuong Nguyen --- dpcpp/base/batch_multi_vector_kernels.hpp.inc | 100 ++++ dpcpp/matrix/batch_dense_kernels.dp.cpp | 54 +-- dpcpp/matrix/batch_dense_kernels.hpp.inc | 18 +- dpcpp/matrix/batch_ell_kernels.dp.cpp | 54 +-- dpcpp/matrix/batch_ell_kernels.hpp.inc | 21 +- dpcpp/preconditioner/batch_identity.hpp.inc | 2 +- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 222 ++++++++- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 449 
++++++++++++++++++ test/solver/CMakeLists.txt | 2 +- 9 files changed, 834 insertions(+), 88 deletions(-) create mode 100644 dpcpp/solver/batch_bicgstab_kernels.hpp.inc diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index 22d00d780f9..828833b6ea3 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -67,6 +67,49 @@ __dpct_inline__ void add_scaled_kernel( } +template +__dpct_inline__ void single_rhs_compute_dot( + const int num_rows, const ValueType* const __restrict__ x, + const ValueType* const __restrict__ y, ValueType& result, + sycl::nd_item<3> item_ct1) +{ + const auto group = item_ct1.get_group(); + const auto group_size = item_ct1.get_local_range().size(); + const auto tid = item_ct1.get_local_linear_id(); + + ValueType val = zero(); + + for (int r = tid; r < num_rows; r += group_size) { + val += conj(x[r]) * y[r]; + } + result = sycl::reduce_over_group(group, val, sycl::plus<>()); +} + + +template +__dpct_inline__ void single_rhs_compute_dot_sg( + const int num_rows, const ValueType* const __restrict__ x, + const ValueType* const __restrict__ y, ValueType& result, + sycl::nd_item<3> item_ct1) +{ + const auto sg = item_ct1.get_sub_group(); + const auto sg_size = sg.get_local_range().size(); + const auto sg_tid = sg.get_local_id(); + + ValueType val = zero(); + + for (int r = sg_tid; r < num_rows; r += sg_size) { + val += conj(x[r]) * y[r]; + } + + val = sycl::reduce_over_group(sg, val, sycl::plus<>()); + + if (sg_tid == 0) { + result = val; + } +} + + template __dpct_inline__ void compute_gen_dot_product_kernel( const gko::batch::multi_vector::batch_item& x, @@ -102,6 +145,52 @@ __dpct_inline__ void compute_gen_dot_product_kernel( } +template +__dpct_inline__ void single_rhs_compute_norm2_sg( + const int num_rows, const ValueType* const __restrict__ x, + gko::remove_complex& result, sycl::nd_item<3> item_ct1) +{ + const auto sg = item_ct1.get_sub_group(); + const auto sg_size = sg.get_local_range().size(); + const auto sg_tid = sg.get_local_id(); + + using real_type = typename gko::remove_complex; + real_type val = zero(); + + for (int r = sg_tid; r < num_rows; r += sg_size) { + val += squared_norm(x[r]); + } + + val = sycl::reduce_over_group(sg, val, sycl::plus<>()); + + if (sg_tid == 0) { + result = sqrt(val); + } +} + + +template +__dpct_inline__ void single_rhs_compute_norm2( + const int num_rows, const ValueType* const __restrict__ x, + gko::remove_complex& result, sycl::nd_item<3> item_ct1) +{ + const auto group = item_ct1.get_group(); + const auto group_size = item_ct1.get_local_range().size(); + const auto tid = item_ct1.get_local_linear_id(); + + using real_type = typename gko::remove_complex; + real_type val = zero(); + + for (int r = tid; r < num_rows; r += group_size) { + val += squared_norm(x[r]); + } + + val = sycl::reduce_over_group(group, val, sycl::plus<>()); + + result = sqrt(val); +} + + template __dpct_inline__ void compute_norm2_kernel( const gko::batch::multi_vector::batch_item& x, @@ -136,6 +225,17 @@ __dpct_inline__ void compute_norm2_kernel( } +template +__dpct_inline__ void copy_kernel(const int num_rows, const ValueType* in, + ValueType* out, sycl::nd_item<3>& item_ct1) +{ + for (int iz = item_ct1.get_local_linear_id(); iz < num_rows; + iz += item_ct1.get_local_range().size()) { + out[iz] = in[iz]; + } +} + + template __dpct_inline__ void copy_kernel( const gko::batch::multi_vector::batch_item& in, diff --git 
a/dpcpp/matrix/batch_dense_kernels.dp.cpp b/dpcpp/matrix/batch_dense_kernels.dp.cpp index a6fba2df8e3..a80ef047e8d 100644 --- a/dpcpp/matrix/batch_dense_kernels.dp.cpp +++ b/dpcpp/matrix/batch_dense_kernels.dp.cpp @@ -100,17 +100,17 @@ void simple_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b, x_b, item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + simple_apply_kernel(mat_b, b_b.values, x_b.values, item_ct1); + }); }); } @@ -147,22 +147,22 @@ void advanced_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto alpha_b = - batch::extract_batch_item(alpha_ub, group_id); - const auto beta_b = - batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, - item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto beta_b = + batch::extract_batch_item(beta_ub, group_id); + advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, + beta_b.values[0], x_b.values, item_ct1); + }); }); } diff --git a/dpcpp/matrix/batch_dense_kernels.hpp.inc b/dpcpp/matrix/batch_dense_kernels.hpp.inc index 88ef5f54764..ba232ea02e4 100644 --- a/dpcpp/matrix/batch_dense_kernels.hpp.inc +++ b/dpcpp/matrix/batch_dense_kernels.hpp.inc @@ -33,9 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
template __dpct_inline__ void simple_apply_kernel( const gko::batch::matrix::dense::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& x, - sycl::nd_item<3>& item_ct1) + const ValueType* b, ValueType* x, sycl::nd_item<3>& item_ct1) { constexpr auto tile_size = config::warp_size; auto subg = @@ -50,14 +48,14 @@ __dpct_inline__ void simple_apply_kernel( for (int j = subgroup.get_local_id(); j < mat.num_cols; j += subgroup_size) { const ValueType val = mat.values[row * mat.stride + j]; - temp += val * b.values[j]; + temp += val * b[j]; } temp = ::gko::kernels::dpcpp::reduce( subg, temp, [](ValueType a, ValueType b) { return a + b; }); if (subgroup.get_local_id() == 0) { - x.values[row] = temp; + x[row] = temp; } } } @@ -65,11 +63,9 @@ __dpct_inline__ void simple_apply_kernel( template __dpct_inline__ void advanced_apply_kernel( - const gko::batch::multi_vector::batch_item& alpha, + const ValueType alpha, const gko::batch::matrix::dense::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& beta, - const gko::batch::multi_vector::batch_item& x, + const ValueType* b, const ValueType beta, ValueType* x, sycl::nd_item<3>& item_ct1) { constexpr auto tile_size = config::warp_size; @@ -85,14 +81,14 @@ __dpct_inline__ void advanced_apply_kernel( for (int j = subgroup.get_local_id(); j < mat.num_cols; j += subgroup_size) { const ValueType val = mat.values[row * mat.stride + j]; - temp += alpha.values[0] * val * b.values[j]; + temp += alpha * val * b[j]; } temp = ::gko::kernels::dpcpp::reduce( subg, temp, [](ValueType a, ValueType b) { return a + b; }); if (subgroup.get_local_id() == 0) { - x.values[row] = temp + beta.values[0] * x.values[row]; + x[row] = temp + beta * x[row]; } } } diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp index 5a69bbd3d5d..1ebd41a7e24 100644 --- a/dpcpp/matrix/batch_ell_kernels.dp.cpp +++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp @@ -97,17 +97,17 @@ void simple_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b, x_b, item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + simple_apply_kernel(mat_b, b_b.values, x_b.values, item_ct1); + }); }); } @@ -145,22 +145,22 @@ void advanced_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = 
group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto alpha_b = - batch::extract_batch_item(alpha_ub, group_id); - const auto beta_b = - batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, - item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto beta_b = + batch::extract_batch_item(beta_ub, group_id); + advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, + beta_b.values[0], x_b.values, item_ct1); + }); }); } diff --git a/dpcpp/matrix/batch_ell_kernels.hpp.inc b/dpcpp/matrix/batch_ell_kernels.hpp.inc index 64d71710dbb..8c54d48db7d 100644 --- a/dpcpp/matrix/batch_ell_kernels.hpp.inc +++ b/dpcpp/matrix/batch_ell_kernels.hpp.inc @@ -33,9 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. template __dpct_inline__ void simple_apply_kernel( const gko::batch::matrix::ell::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& x, - sycl::nd_item<3>& item_ct1) + const ValueType* b, ValueType* x, sycl::nd_item<3>& item_ct1) { for (int tidx = item_ct1.get_local_linear_id(); tidx < mat.num_rows; tidx += item_ct1.get_local_range().size()) { @@ -45,22 +43,19 @@ __dpct_inline__ void simple_apply_kernel( if (col_idx == invalid_index()) { break; } else { - temp += mat.values[tidx + idx * mat.stride] * - b.values[col_idx * b.stride]; + temp += mat.values[tidx + idx * mat.stride] * b[col_idx]; } } - x.values[tidx * x.stride] = temp; + x[tidx] = temp; } } template __dpct_inline__ void advanced_apply_kernel( - const gko::batch::multi_vector::batch_item& alpha, + const ValueType alpha, const gko::batch::matrix::ell::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& beta, - const gko::batch::multi_vector::batch_item& x, + const ValueType* b, const ValueType beta, ValueType* x, sycl::nd_item<3>& item_ct1) { for (int tidx = item_ct1.get_local_linear_id(); tidx < mat.num_rows; @@ -71,11 +66,9 @@ __dpct_inline__ void advanced_apply_kernel( if (col_idx == invalid_index()) { break; } else { - temp += mat.values[tidx + idx * mat.stride] * - b.values[col_idx * b.stride]; + temp += mat.values[tidx + idx * mat.stride] * b[col_idx]; } } - x.values[tidx * x.stride] = - alpha.values[0] * temp + beta.values[0] * x.values[tidx * x.stride]; + x[tidx] = alpha * temp + beta * x[tidx]; } } diff --git a/dpcpp/preconditioner/batch_identity.hpp.inc b/dpcpp/preconditioner/batch_identity.hpp.inc index e15a4d37399..53e2f70a7d9 100644 --- a/dpcpp/preconditioner/batch_identity.hpp.inc +++ b/dpcpp/preconditioner/batch_identity.hpp.inc @@ -44,7 +44,7 @@ public: void generate(size_type batch_id, const gko::batch::matrix::ell::batch_item&, + gko::int32>&, ValueType* const, sycl::nd_item<3> item_ct1) {} diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 
81519d8e2aa..b4cb227fe03 100644
--- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
+++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
@@ -33,12 +33,26 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include "core/solver/batch_bicgstab_kernels.hpp"
-#include
-#include
+#include
+#include
+#include
+#include
+
+
+#include "core/base/batch_struct.hpp"
+#include "core/matrix/batch_struct.hpp"
 #include "core/solver/batch_dispatch.hpp"
+#include "dpcpp/base/batch_struct.hpp"
 #include "dpcpp/base/config.hpp"
+#include "dpcpp/base/dim3.dp.hpp"
+#include "dpcpp/base/dpct.hpp"
+#include "dpcpp/base/helper.hpp"
+#include "dpcpp/components/cooperative_groups.dp.hpp"
+#include "dpcpp/components/intrinsics.dp.hpp"
+#include "dpcpp/components/reduction.dp.hpp"
+#include "dpcpp/components/thread_ids.dp.hpp"
 #include "dpcpp/matrix/batch_struct.hpp"
@@ -46,26 +60,220 @@ namespace gko {
 namespace kernels {
 namespace dpcpp {
 /**
  * @brief The batch Bicgstab solver namespace.
  *
  * @ingroup batch_bicgstab
  */
 namespace batch_bicgstab {
+#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"
+#include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
+#include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
+#include "dpcpp/solver/batch_bicgstab_kernels.hpp.inc"
+
+
 template
 using settings = gko::kernels::batch_bicgstab::settings;
+__dpct_inline__ int get_group_size(int value, int simd_len = 32)
+{
+    int num_sg = (value + simd_len - 1) / simd_len;
+    return (num_sg * simd_len);
+}
+
+
+template
+class KernelCaller {
+public:
+    KernelCaller(std::shared_ptr exec,
+                 const settings<remove_complex<ValueType>> settings)
+        : exec_{exec}, settings_{settings}
+    {}
+
+    template
+    __dpct_inline__ void launch_apply_kernel(
+        const gko::kernels::batch_bicgstab::storage_config& sconf,
+        LogType& logger, PrecType& prec, const BatchMatrixType mat,
+        const ValueType* const __restrict__ b_values,
+        ValueType* const __restrict__ x_values,
+        ValueType* const __restrict__ workspace, const int& group_size,
+        const int& shared_size) const
+    {
+        auto num_rows = mat.num_rows;
+
+        const dim3 block(group_size);
+        const dim3 grid(mat.num_batch_items);
+
+        auto max_iters = settings_.max_iterations;
+        auto res_tol = settings_.residual_tol;
+
+        (exec_->get_queue())->submit([&](sycl::handler& cgh) {
+            sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
+                           sycl::access::target::local>
+                slm_values(sycl::range<1>(shared_size), cgh);
+
+            cgh.parallel_for(
+                sycl_nd_range(grid, block),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(simd_len)]]
+                    [[intel::kernel_args_restrict]] {
+                    auto batch_id = item_ct1.get_group_linear_id();
+                    const auto mat_global_entry =
+                        gko::batch::matrix::extract_batch_item(mat, batch_id);
+                    const ValueType* const b_global_entry =
+                        gko::batch::multi_vector::batch_item_ptr(
+                            b_values, 1, num_rows, batch_id);
+                    ValueType* const x_global_entry =
+                        gko::batch::multi_vector::batch_item_ptr(
+                            x_values, 1, num_rows, batch_id);
+                    apply_kernel(
+                        sconf, max_iters, res_tol, logger, prec,
+                        mat_global_entry, b_global_entry, x_global_entry,
+                        num_rows, mat.get_single_item_num_nnz(),
+                        static_cast<ValueType*>(slm_values.get_pointer()),
+                        item_ct1, workspace);
+                });
+        });
+    }
+
+    template
+    void call_kernel(
+        LogType logger, const BatchMatrixType& mat, PrecType prec,
+        const gko::batch::multi_vector::uniform_batch& b,
+        const gko::batch::multi_vector::uniform_batch& x) const
+    {
+        using real_type = gko::remove_complex<ValueType>;
+        const size_type num_batch_items = mat.num_batch_items;
+        const auto num_rows = mat.num_rows;
+        const auto num_rhs = b.num_rhs;
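+        // Launch configuration: one work-group per batch item. The group
+        // size defaults to the device's maximum work-group size and, for
+        // systems with fewer rows than that, shrinks to the row count
+        // rounded up to whole sub-groups, e.g. get_group_size(100) gives
+        // 4 sub-groups of 32 work-items, i.e. a group size of 128.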
+        GKO_ASSERT(num_rhs == 1);
+
+        auto device = exec_->get_queue()->get_device();
+        auto group_size = device.get_info<
+            sycl::info::device::max_work_group_size>();
+        if (group_size > num_rows) {
+            group_size = get_group_size(num_rows);
+        }
+
+        int shmem_per_blk =
+            static_cast<int>(
+                device.get_info<sycl::info::device::local_mem_size>()) -
+            static_cast<int>((group_size + 5) * sizeof(ValueType) +
+                             2 * sizeof(real_type));
+        // reserved: 5 values for the scalar intermediates (the two rho-s,
+        // alpha, omega and temp), 2 reals for the rhs and residual norms,
+        // and one value per work-item for reduce_over_group
+        if (shmem_per_blk < 0) shmem_per_blk = 0;
+        const int shared_gap = num_rows;
+        const size_type prec_size = PrecType::dynamic_work_size(
+            shared_gap, mat.get_single_item_num_nnz());
+        const auto sconf =
+            gko::kernels::batch_bicgstab::compute_shared_storage(
+                shmem_per_blk, shared_gap, mat.get_single_item_num_nnz(),
+                b.num_rhs);
+        const size_t shared_size =
+            sconf.n_shared * shared_gap + (sconf.prec_shared ? prec_size : 0);
+        auto workspace = gko::array(
+            exec_,
+            sconf.gmem_stride_bytes * num_batch_items / sizeof(ValueType));
+        assert(sconf.gmem_stride_bytes % sizeof(ValueType) == 0);
+
+        ValueType* const workspace_data = workspace.get_data();
+        int n_shared_total = sconf.n_shared + int(sconf.prec_shared);
+
+        // dispatch to the launch_apply_kernel instantiation that matches
+        // the number of vectors (and preconditioner workspace) kept in
+        // shared memory
+        if (num_rows <= 32 && n_shared_total == 10)
+            launch_apply_kernel(
+                sconf, logger, prec, mat, b.values, x.values, workspace_data,
+                group_size, shared_size);
+        else if (num_rows <= 256 && n_shared_total == 10)
+            launch_apply_kernel(
+                sconf, logger, prec, mat, b.values, x.values, workspace_data,
+                group_size, shared_size);
+        else {
+            switch (n_shared_total) {
+            case 0:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 1:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 2:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 3:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 4:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 5:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 6:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 7:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 8:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 9:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            case 10:
+                launch_apply_kernel(
+                    sconf, logger, prec, mat, b.values, x.values,
+                    workspace_data, group_size, shared_size);
+                break;
+            }
+        }
+    }
+
+private:
+    std::shared_ptr exec_;
+    const settings<remove_complex<ValueType>> settings_;
+};
+
+
 template
 void apply(std::shared_ptr exec,
           const settings<remove_complex<ValueType>>& settings,
-          const batch::BatchLinOp* const a,
-          const batch::BatchLinOp* const precon,
+          const batch::BatchLinOp* const mat,
+          const batch::BatchLinOp* const precond,
           const batch::MultiVector* const b,
           batch::MultiVector* const x,
           batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
-    GKO_NOT_IMPLEMENTED;
+{
+    auto dispatcher = batch::solver::create_dispatcher(
+        KernelCaller(exec, settings), settings, mat,
precond); + dispatcher.apply(b, x, logdata); +} + GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc new file mode 100644 index 00000000000..c7ad625b9af --- /dev/null +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -0,0 +1,449 @@ +/************************************************************* +Copyright (c) 2017-2023, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*************************************************************/ + +template +__dpct_inline__ void initialize( + const int num_rows, const BatchMatrixType_entry& mat_global_entry, + const ValueType* const b_global_entry, + const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega, + ValueType& alpha, ValueType* const x_shared_entry, + ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry, + ValueType* const p_shared_entry, ValueType* const v_shared_entry, + typename gko::remove_complex& rhs_norm, + typename gko::remove_complex& res_norm, + sycl::nd_item<3> item_ct1) +{ + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); + const auto group_size = item_ct1.get_local_range().size(); + const auto group = item_ct1.get_group(); + + rho_old = one(); + omega = one(); + alpha = one(); + + // copy x from global to shared memory + // r = b + for (int iz = tid; iz < num_rows; iz += group_size) { + x_shared_entry[iz] = x_global_entry[iz]; + r_shared_entry[iz] = b_global_entry[iz]; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + // r = b - A*x + advanced_apply_kernel(static_cast(-1.0), mat_global_entry, + x_shared_entry, static_cast(1.0), + r_shared_entry, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + if constexpr (sg_kernel_all) { + if (sg_id == 0) { + single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, + item_ct1); + } else if (sg_id == 1) { + single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, + item_ct1); + } + } else { + single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); + single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1); + } + + + for (int iz = tid; iz < num_rows; iz += group_size) { + r_hat_shared_entry[iz] = r_shared_entry[iz]; + p_shared_entry[iz] = zero(); + v_shared_entry[iz] = zero(); + } +} + + +template +__dpct_inline__ void update_p(const int num_rows, const ValueType& rho_new, + const ValueType& rho_old, const ValueType& alpha, + const ValueType& omega, + const ValueType* const r_shared_entry, + const ValueType* const v_shared_entry, + ValueType* const p_shared_entry, + sycl::nd_item<3> item_ct1) +{ + const ValueType beta = (rho_new / rho_old) * (alpha / omega); + for (int r = item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + p_shared_entry[r] = + r_shared_entry[r] + + beta * (p_shared_entry[r] - omega * v_shared_entry[r]); + } +} + +template +__dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, + const ValueType* const r_hat_shared_entry, + const ValueType* const v_shared_entry, + ValueType& alpha, sycl::nd_item<3> item_ct1) +{ + if constexpr (sg_kernel_all) { + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); + + if (sg_id == 0) { + single_rhs_compute_dot_sg(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); + } + if (tid == 0) { + alpha = rho_new / alpha; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } else { + single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, + alpha, item_ct1); + alpha = rho_new / alpha; + } +} + + +template +__dpct_inline__ void update_s(const int num_rows, + const ValueType* const r_shared_entry, + const ValueType& alpha, + const ValueType* const v_shared_entry, + ValueType* const s_shared_entry, + sycl::nd_item<3> item_ct1) +{ + for (int r = 
item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + s_shared_entry[r] = r_shared_entry[r] - alpha * v_shared_entry[r]; + } +} + + +template +__dpct_inline__ void compute_omega(const int num_rows, + const ValueType* const t_shared_entry, + const ValueType* const s_shared_entry, + ValueType& temp, ValueType& omega, + sycl::nd_item<3> item_ct1) +{ + if constexpr (sg_kernel_all) { + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); + + if (sg_id == 0) + single_rhs_compute_dot_sg(num_rows, t_shared_entry, + s_shared_entry, omega, item_ct1); + else if (sg_id == 1) + single_rhs_compute_dot_sg(num_rows, t_shared_entry, + t_shared_entry, temp, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + if (tid == 0) omega /= temp; + item_ct1.barrier(sycl::access::fence_space::local_space); + } else { + single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, + omega, item_ct1); + single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, + temp, item_ct1); + omega /= temp; + } +} + +template +__dpct_inline__ void update_x_and_r( + const int num_rows, const ValueType* const p_hat_shared_entry, + const ValueType* const s_hat_shared_entry, const ValueType& alpha, + const ValueType& omega, const ValueType* const s_shared_entry, + const ValueType* const t_shared_entry, ValueType* const x_shared_entry, + ValueType* const r_shared_entry, sycl::nd_item<3> item_ct1) +{ + for (int r = item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r] + + omega * s_hat_shared_entry[r]; + r_shared_entry[r] = s_shared_entry[r] - omega * t_shared_entry[r]; + } +} + + +template +__dpct_inline__ void update_x_middle(const int num_rows, const ValueType& alpha, + const ValueType* const p_hat_shared_entry, + ValueType* const x_shared_entry, + sycl::nd_item<3> item_ct1) +{ + for (int r = item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r]; + } +} + + +template +void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, + const int max_iter, const gko::remove_complex tol, + LogType logger, PrecType prec_shared, + const BatchMatrixType mat_global_entry, + const ValueType* const __restrict__ b_global_entry, + ValueType* const __restrict__ x_global_entry, + const size_type num_rows, const size_type nnz, + ValueType* const __restrict__ slm_values, + sycl::nd_item<3> item_ct1, + ValueType* const __restrict__ workspace = nullptr) +{ + using real_type = typename gko::remove_complex; + + const auto sg = item_ct1.get_sub_group(); + const int sg_id = sg.get_group_id(); + const int tid = item_ct1.get_local_linear_id(); + auto group = item_ct1.get_group(); + const int group_size = item_ct1.get_local_range().size(); + + const auto batch_id = item_ct1.get_group_linear_id(); + + ValueType* rho_old_sh; + ValueType* rho_new_sh; + ValueType* alpha_sh; + ValueType* omega_sh; + ValueType* temp_sh; + real_type* norms_rhs_sh; + real_type* norms_res_sh; + + if constexpr (sg_kernel_all) { + using tile_value_t = ValueType[5]; + tile_value_t& values = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + using tile_real_t = real_type[2]; + tile_real_t& reals = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + rho_old_sh = &values[0]; + rho_new_sh = 
&values[1]; + alpha_sh = &values[2]; + omega_sh = &values[3]; + temp_sh = &values[4]; + norms_rhs_sh = &reals[0]; + norms_res_sh = &reals[1]; + } else { + ValueType values[5]; + real_type reals[2]; + rho_old_sh = &values[0]; + rho_new_sh = &values[1]; + alpha_sh = &values[2]; + omega_sh = &values[3]; + temp_sh = &values[4]; + norms_rhs_sh = &reals[0]; + norms_res_sh = &reals[1]; + } + const int gmem_offset = + batch_id * sconf.gmem_stride_bytes / sizeof(ValueType); + ValueType* p_hat_sh; + ValueType* s_hat_sh; + ValueType* s_sh; + ValueType* p_sh; + ValueType* r_sh; + ValueType* r_hat_sh; + ValueType* v_sh; + ValueType* t_sh; + ValueType* x_sh; + ValueType* prec_work_sh; + + if constexpr (n_shared_total >= 1) { + p_hat_sh = slm_values; + } else { + p_hat_sh = workspace + gmem_offset; + } + if constexpr (n_shared_total == 1) { + s_hat_sh = workspace + gmem_offset; + } else { + s_hat_sh = p_hat_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 2) { + v_sh = workspace + gmem_offset; + } else { + v_sh = s_hat_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 3) { + t_sh = workspace + gmem_offset; + } else { + t_sh = v_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 4) { + p_sh = workspace + gmem_offset; + } else { + p_sh = t_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 5) { + s_sh = workspace + gmem_offset; + } else { + s_sh = p_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 6) { + r_sh = workspace + gmem_offset; + } else { + r_sh = s_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 7) { + r_hat_sh = workspace + gmem_offset; + } else { + r_hat_sh = r_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 8) { + x_sh = workspace + gmem_offset; + } else { + x_sh = r_hat_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 9) { + prec_work_sh = workspace + gmem_offset; + } else { + prec_work_sh = x_sh + sconf.padded_vec_len; + } + + // generate preconditioner + prec_shared.generate(batch_id, mat_global_entry, prec_work_sh, item_ct1); + + // initialization + // rho_old = 1, omega = 1, alpha = 1 + // compute b norms + // copy x from global to shared memory + // r = b - A*x + // compute residual norms + // r_hat = r + // p = 0 + // v = 0 + initialize(num_rows, mat_global_entry, b_global_entry, + x_global_entry, rho_old_sh[0], omega_sh[0], + alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, v_sh, + norms_rhs_sh[0], norms_res_sh[0], item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + // stopping criterion object + StopType stop(tol, norms_rhs_sh); + + int iter = 0; + for (; iter < max_iter; iter++) { + if (stop.check_converged(norms_res_sh)) { + break; + } + + // rho_new = < r_hat , r > = (r_hat)' * (r) + if constexpr (sg_kernel_all) { + if (sg_id == 0) { + single_rhs_compute_dot_sg(num_rows, r_hat_sh, r_sh, + rho_new_sh[0], item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } else { + single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], + item_ct1); + } + + // beta = (rho_new / rho_old)*(alpha / omega) + // p = r + beta*(p - omega * v) + update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], omega_sh[0], + r_sh, v_sh, p_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + // p_hat = precond * p + prec_shared.apply(num_rows, p_sh, p_hat_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + // v = A * p_hat + simple_apply_kernel(mat_global_entry, p_hat_sh, v_sh, item_ct1); + 
item_ct1.barrier(sycl::access::fence_space::local_space); + + // alpha = rho_new / < r_hat , v> + compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, + alpha_sh[0], item_ct1); + // item_ct1.barrier(sycl::access::fence_space::local_space); + + // s = r - alpha*v + update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + // an estimate of residual norms + if constexpr (sg_kernel_all) { + if (sg_id == 0) { + single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } else { + single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1); + } + + // if (norms_res_sh[0] / norms_rhs_sh[0] < tol) { + if (stop.check_converged(norms_res_sh)) { + update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh, item_ct1); + break; + } + + // s_hat = precond * s + prec_shared.apply(num_rows, s_sh, s_hat_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + // t = A * s_hat + simple_apply_kernel(mat_global_entry, s_hat_sh, t_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + // omega = / + compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], + item_ct1); + // item_ct1.barrier(sycl::access::fence_space::local_space); + + // x = x + alpha*p_hat + omega *s_hat + // r = s - omega * t + update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], omega_sh[0], + s_sh, t_sh, x_sh, r_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); + + if constexpr (sg_kernel_all) { + if (sg_id == 0) + single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], item_ct1); + if (tid == group_size - 1) { + rho_old_sh[0] = rho_new_sh[0]; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } else { + single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1); + rho_old_sh[0] = rho_new_sh[0]; + } + } + + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + + // copy x back to global memory + copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); +} diff --git a/test/solver/CMakeLists.txt b/test/solver/CMakeLists.txt index de3430393ae..00c78eb93a0 100644 --- a/test/solver/CMakeLists.txt +++ b/test/solver/CMakeLists.txt @@ -1,4 +1,4 @@ -ginkgo_create_common_test(batch_bicgstab_kernels DISABLE_EXECUTORS dpcpp) +ginkgo_create_common_test(batch_bicgstab_kernels) ginkgo_create_common_test(bicg_kernels) ginkgo_create_common_test(bicgstab_kernels) ginkgo_create_common_test(cb_gmres_kernels) From 1ef1f68d036b2654163446999e09706017b6860e Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sat, 28 Oct 2023 22:07:31 +0200 Subject: [PATCH 04/28] Fix dpcpp kernel issues --- dpcpp/base/batch_multi_vector_kernels.hpp.inc | 45 ++++++++++------ dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 24 +++++---- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 52 ++++++++++--------- 3 files changed, 70 insertions(+), 51 deletions(-) diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index 828833b6ea3..4db1dc5e1d7 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -67,12 +67,15 @@ __dpct_inline__ void add_scaled_kernel( } -template +template __dpct_inline__ void single_rhs_compute_dot( const int num_rows, const ValueType* const __restrict__ x, const ValueType* const __restrict__ y, ValueType& result, sycl::nd_item<3> item_ct1) { + // auto grp = + // 
group::tiled_partition(group::this_thread_block(item_ct1)); + // auto grp = group::this_thread_block(item_ct1); const auto group = item_ct1.get_group(); const auto group_size = item_ct1.get_local_range().size(); const auto tid = item_ct1.get_local_linear_id(); @@ -86,25 +89,29 @@ __dpct_inline__ void single_rhs_compute_dot( } -template +template __dpct_inline__ void single_rhs_compute_dot_sg( const int num_rows, const ValueType* const __restrict__ x, const ValueType* const __restrict__ y, ValueType& result, sycl::nd_item<3> item_ct1) { - const auto sg = item_ct1.get_sub_group(); - const auto sg_size = sg.get_local_range().size(); - const auto sg_tid = sg.get_local_id(); + auto subg = + group::tiled_partition(group::this_thread_block(item_ct1)); + const auto subgroup = static_cast(subg); + const int subgroup_id = subgroup.get_group_id(); + const int subgroup_size = subgroup.get_local_range().size(); + const auto subgroup_tid = subgroup.get_local_id(); ValueType val = zero(); - for (int r = sg_tid; r < num_rows; r += sg_size) { + for (int r = subgroup_tid; r < num_rows; r += subgroup_size) { val += conj(x[r]) * y[r]; } - val = sycl::reduce_over_group(sg, val, sycl::plus<>()); + val = ::gko::kernels::dpcpp::reduce( + subg, val, [](ValueType a, ValueType b) { return a + b; }); - if (sg_tid == 0) { + if (subgroup_tid == 0) { result = val; } } @@ -145,25 +152,27 @@ __dpct_inline__ void compute_gen_dot_product_kernel( } -template +template __dpct_inline__ void single_rhs_compute_norm2_sg( const int num_rows, const ValueType* const __restrict__ x, gko::remove_complex& result, sycl::nd_item<3> item_ct1) { - const auto sg = item_ct1.get_sub_group(); - const auto sg_size = sg.get_local_range().size(); - const auto sg_tid = sg.get_local_id(); + auto subg = + group::tiled_partition(group::this_thread_block(item_ct1)); + const auto subgroup = static_cast(subg); + const int subgroup_id = subgroup.get_group_id(); + const int subgroup_size = subgroup.get_local_range().size(); using real_type = typename gko::remove_complex; real_type val = zero(); - for (int r = sg_tid; r < num_rows; r += sg_size) { + for (int r = subgroup.get_local_id(); r < num_rows; r += subgroup_size) val += squared_norm(x[r]); - } - val = sycl::reduce_over_group(sg, val, sycl::plus<>()); + val = ::gko::kernels::dpcpp::reduce( + subg, val, [](real_type a, real_type b) { return a + b; }); - if (sg_tid == 0) { + if (subgroup.get_local_id() == 0) { result = sqrt(val); } } @@ -174,6 +183,8 @@ __dpct_inline__ void single_rhs_compute_norm2( const int num_rows, const ValueType* const __restrict__ x, gko::remove_complex& result, sycl::nd_item<3> item_ct1) { + // auto grp = + // group::tiled_partition(group::this_thread_block(item_ct1)); const auto group = item_ct1.get_group(); const auto group_size = item_ct1.get_local_range().size(); const auto tid = item_ct1.get_local_linear_id(); @@ -186,6 +197,8 @@ __dpct_inline__ void single_rhs_compute_norm2( } val = sycl::reduce_over_group(group, val, sycl::plus<>()); + // val = ::gko::kernels::dpcpp::reduce( + // grp, val, [](real_type a, real_type b) { return a + b; }); result = sqrt(val); } diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index b4cb227fe03..61c888b357b 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -196,60 +196,62 @@ class KernelCaller { else { switch (n_shared_total) { case 0: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, 
workspace_data, group_size, shared_size); break; case 1: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 2: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 3: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 4: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 5: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 6: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 7: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 8: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 9: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 10: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; + default: + GKO_NOT_IMPLEMENTED; } } } diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index c7ad625b9af..e0affebe3c2 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -63,17 +63,17 @@ __dpct_inline__ void initialize( // r = b - A*x advanced_apply_kernel(static_cast(-1.0), mat_global_entry, - x_shared_entry, static_cast(1.0), - r_shared_entry, item_ct1); + x_shared_entry, static_cast(1.0), + r_shared_entry, item_ct1); item_ct1.barrier(sycl::access::fence_space::local_space); if constexpr (sg_kernel_all) { if (sg_id == 0) { single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, - item_ct1); + item_ct1); } else if (sg_id == 1) { single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, - item_ct1); + item_ct1); } } else { single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); @@ -120,7 +120,7 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, if (sg_id == 0) { single_rhs_compute_dot_sg(num_rows, r_hat_shared_entry, - v_shared_entry, alpha, item_ct1); + v_shared_entry, alpha, item_ct1); } if (tid == 0) { alpha = rho_new / alpha; @@ -128,7 +128,7 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, item_ct1.barrier(sycl::access::fence_space::local_space); } else { single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, - alpha, item_ct1); + alpha, item_ct1); alpha = rho_new / alpha; } } @@ -162,19 +162,19 @@ __dpct_inline__ void compute_omega(const int num_rows, const auto tid = item_ct1.get_local_linear_id(); if (sg_id == 0) - single_rhs_compute_dot_sg(num_rows, t_shared_entry, - s_shared_entry, omega, item_ct1); + single_rhs_compute_dot_sg(num_rows, t_shared_entry, s_shared_entry, + omega, item_ct1); else if (sg_id == 1) - single_rhs_compute_dot_sg(num_rows, t_shared_entry, - t_shared_entry, temp, item_ct1); + single_rhs_compute_dot_sg(num_rows, t_shared_entry, 
t_shared_entry, + temp, item_ct1); item_ct1.barrier(sycl::access::fence_space::local_space); if (tid == 0) omega /= temp; item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, - omega, item_ct1); - single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, - temp, item_ct1); + single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, omega, + item_ct1); + single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, temp, + item_ct1); omega /= temp; } } @@ -356,6 +356,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, int iter = 0; for (; iter < max_iter; iter++) { if (stop.check_converged(norms_res_sh)) { + logger.log_iteration(batch_id, iter, norms_res_sh[0]); break; } @@ -363,18 +364,18 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, if constexpr (sg_kernel_all) { if (sg_id == 0) { single_rhs_compute_dot_sg(num_rows, r_hat_sh, r_sh, - rho_new_sh[0], item_ct1); + rho_new_sh[0], item_ct1); } item_ct1.barrier(sycl::access::fence_space::local_space); } else { single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], - item_ct1); + item_ct1); } // beta = (rho_new / rho_old)*(alpha / omega) // p = r + beta*(p - omega * v) - update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], omega_sh[0], - r_sh, v_sh, p_sh, item_ct1); + update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], + omega_sh[0], r_sh, v_sh, p_sh, item_ct1); item_ct1.barrier(sycl::access::fence_space::local_space); // p_hat = precond * p @@ -388,7 +389,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // alpha = rho_new / < r_hat , v> compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0], item_ct1); - // item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::local_space); // s = r - alpha*v update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1); @@ -397,7 +398,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // an estimate of residual norms if constexpr (sg_kernel_all) { if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], item_ct1); + single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], + item_ct1); } item_ct1.barrier(sycl::access::fence_space::local_space); } else { @@ -407,6 +409,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // if (norms_res_sh[0] / norms_rhs_sh[0] < tol) { if (stop.check_converged(norms_res_sh)) { update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh, item_ct1); + logger.log_iteration(batch_id, iter, norms_res_sh[0]); break; } @@ -419,9 +422,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, item_ct1.barrier(sycl::access::fence_space::local_space); // omega = / - compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], - item_ct1); - // item_ct1.barrier(sycl::access::fence_space::local_space); + compute_omega(num_rows, t_sh, s_sh, temp_sh[0], + omega_sh[0], item_ct1); + item_ct1.barrier(sycl::access::fence_space::local_space); // x = x + alpha*p_hat + omega *s_hat // r = s - omega * t @@ -431,7 +434,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, if constexpr (sg_kernel_all) { if (sg_id == 0) - single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], item_ct1); + single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], + item_ct1); if (tid == group_size - 1) { 
rho_old_sh[0] = rho_new_sh[0]; } From 6bcdd570c0f8f35c4103213eab490c19b9e613ef Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sun, 29 Oct 2023 14:03:36 +0100 Subject: [PATCH 05/28] add mvec single rhs specializations --- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 151 +++++++++++++------ dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 39 +++-- test/base/batch_multi_vector_kernels.cpp | 35 ++++- test/matrix/batch_dense_kernels.cpp | 18 ++- 4 files changed, 171 insertions(+), 72 deletions(-) diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index e0bc15fdc61..0c18fd80806 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -184,28 +184,57 @@ void compute_dot(std::shared_ptr exec, const auto res_ub = get_batch_struct(result); const auto num_batches = x_ub.num_batch_items; + const auto num_rows = x_ub.num_rows; auto device = exec->get_queue()->get_device(); - auto group_size = - device.get_info(); - const dim3 block(group_size); - const dim3 grid(num_batches); - // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group - exec->get_queue()->submit([&](sycl::handler& cgh) { - cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel(x_b, y_b, res_b, item_ct1, - [](auto val) { return val; }); - }); - }); + if (x->get_common_size()[1] == 1) { + int group_size = ((num_rows + 32 - 1) / 32) * 32; + + const dim3 block(group_size); + const dim3 grid(num_batches); + + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_dot_sg(x_b.num_rows, x_b.values, + y_b.values, res_b.values[0], + item_ct1); + }); + }); + } else { + auto group_size = + device.get_info(); + + const dim3 block(group_size); + const dim3 grid(num_batches); + + // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return val; }); + }); + }); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( @@ -232,19 +261,18 @@ void compute_conj_dot(std::shared_ptr exec, exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = 
item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return conj(val); }); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return conj(val); }); + }); }); } @@ -261,26 +289,51 @@ void compute_norm2(std::shared_ptr exec, const auto res_ub = get_batch_struct(result); const auto num_batches = x_ub.num_batch_items; + const auto num_rows = x->get_common_size()[0]; auto device = exec->get_queue()->get_device(); - auto group_size = - device.get_info(); - const dim3 block(group_size); - const dim3 grid(num_batches); + if (x->get_common_size()[1] == 1) { + int group_size = ((num_rows + 32 - 1) / 32) * 32; - exec->get_queue()->submit([&](sycl::handler& cgh) { - cgh.parallel_for(sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = batch::extract_batch_item( - res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); - }); - }); + const dim3 block(group_size); + const dim3 grid(num_batches); + + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, + res_b.values[0], item_ct1); + }); + }); + + } else { + auto group_size = + device.get_info(); + + const dim3 block(group_size); + const dim3 grid(num_batches); + + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_norm2_kernel(x_b, res_b, item_ct1); + }); + }); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index e0affebe3c2..a32f6f39da8 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -30,6 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*************************************************************/ + template __dpct_inline__ void initialize( @@ -76,9 +77,11 @@ __dpct_inline__ void initialize( item_ct1); } } else { - single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); - single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1); + // single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, + // item_ct1); single_rhs_compute_norm2(num_rows, b_global_entry, + // rhs_norm, item_ct1); } + item_ct1.barrier(sycl::access::fence_space::local_space); for (int iz = tid; iz < num_rows; iz += group_size) { @@ -107,6 +110,7 @@ __dpct_inline__ void update_p(const int num_rows, const ValueType& rho_new, } } + template __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, const ValueType* const r_hat_shared_entry, @@ -127,9 +131,9 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, - alpha, item_ct1); - alpha = rho_new / alpha; + // single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, + // alpha, item_ct1); + // alpha = rho_new / alpha; } } @@ -171,14 +175,17 @@ __dpct_inline__ void compute_omega(const int num_rows, if (tid == 0) omega /= temp; item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, omega, - item_ct1); - single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, temp, - item_ct1); - omega /= temp; + // single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, + // omega, + // item_ct1); + // single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, + // temp, + // item_ct1); + // omega /= temp; } } + template __dpct_inline__ void update_x_and_r( const int num_rows, const ValueType* const p_hat_shared_entry, @@ -368,8 +375,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], - item_ct1); + // single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], + // item_ct1); } // beta = (rho_new / rho_old)*(alpha / omega) @@ -403,10 +410,10 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1); + // single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], + // item_ct1); } - // if (norms_res_sh[0] / norms_rhs_sh[0] < tol) { if (stop.check_converged(norms_res_sh)) { update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh, item_ct1); logger.log_iteration(batch_id, iter, norms_res_sh[0]); @@ -441,8 +448,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1); - rho_old_sh[0] = rho_new_sh[0]; + // single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], + // item_ct1); rho_old_sh[0] = rho_new_sh[0]; } } diff --git a/test/base/batch_multi_vector_kernels.cpp b/test/base/batch_multi_vector_kernels.cpp index be625853656..6f4eb3d05a8 100644 --- a/test/base/batch_multi_vector_kernels.cpp +++ b/test/base/batch_multi_vector_kernels.cpp @@ -70,10 +70,9 @@ class MultiVector : public CommonTestFixture { 
std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } - void set_up_vector_data(gko::size_type num_vecs, + void set_up_vector_data(gko::size_type num_vecs, const int num_rows = 252, bool different_alpha = false) { - const int num_rows = 252; x = gen_mtx(batch_size, num_rows, num_vecs); y = gen_mtx(batch_size, num_rows, num_vecs); c_x = gen_mtx(batch_size, num_rows, num_vecs); @@ -143,7 +142,7 @@ TEST_F(MultiVector, MultipleVectorAddScaledIsEquivalentToRef) TEST_F(MultiVector, MultipleVectorAddScaledWithDifferentAlphaIsEquivalentToRef) { - set_up_vector_data(20, true); + set_up_vector_data(20, 252, true); x->add_scaled(alpha.get(), y.get()); dx->add_scaled(dalpha.get(), dy.get()); @@ -185,6 +184,21 @@ TEST_F(MultiVector, MultipleVectorScaleWithDifferentAlphaIsEquivalentToRef) } +TEST_F(MultiVector, ComputeNorm2SingleSmallIsEquivalentToRef) +{ + set_up_vector_data(1, 10); + auto norm_size = + gko::batch_dim<2>(batch_size, gko::dim<2>{1, x->get_common_size()[1]}); + auto norm_expected = NormVector::create(this->ref, norm_size); + auto dnorm = NormVector::create(this->exec, norm_size); + + x->compute_norm2(norm_expected.get()); + dx->compute_norm2(dnorm.get()); + + GKO_ASSERT_BATCH_MTX_NEAR(norm_expected, dnorm, 5 * r::value); +} + + TEST_F(MultiVector, ComputeNorm2SingleIsEquivalentToRef) { set_up_vector_data(1); @@ -250,6 +264,21 @@ TEST_F(MultiVector, ComputeDotSingleIsEquivalentToRef) } +TEST_F(MultiVector, ComputeDotSingleSmallIsEquivalentToRef) +{ + set_up_vector_data(1, 10); + auto dot_size = + gko::batch_dim<2>(batch_size, gko::dim<2>{1, x->get_common_size()[1]}); + auto dot_expected = Mtx::create(this->ref, dot_size); + auto ddot = Mtx::create(this->exec, dot_size); + + x->compute_dot(y.get(), dot_expected.get()); + dx->compute_dot(dy.get(), ddot.get()); + + GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 5 * r::value); +} + + TEST_F(MultiVector, ComputeConjDotIsEquivalentToRef) { set_up_vector_data(20); diff --git a/test/matrix/batch_dense_kernels.cpp b/test/matrix/batch_dense_kernels.cpp index a243d51f3c1..1f3967b0eb8 100644 --- a/test/matrix/batch_dense_kernels.cpp +++ b/test/matrix/batch_dense_kernels.cpp @@ -71,9 +71,8 @@ class Dense : public CommonTestFixture { std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } - void set_up_apply_data(gko::size_type num_vecs = 1) + void set_up_apply_data(int num_rows, gko::size_type num_vecs = 1) { - const int num_rows = 252; const int num_cols = 32; mat = gen_mtx(batch_size, num_rows, num_cols); y = gen_mtx(batch_size, num_cols, num_vecs); @@ -106,9 +105,20 @@ class Dense : public CommonTestFixture { }; +TEST_F(Dense, SingleVectorApplyIsEquivalentToRefForSmallMatrices) +{ + set_up_apply_data(10); + + mat->apply(y.get(), expected.get()); + dmat->apply(dy.get(), dresult.get()); + + GKO_ASSERT_BATCH_MTX_NEAR(dresult, expected, r::value); +} + + TEST_F(Dense, SingleVectorApplyIsEquivalentToRef) { - set_up_apply_data(1); + set_up_apply_data(257); mat->apply(y.get(), expected.get()); dmat->apply(dy.get(), dresult.get()); @@ -119,7 +129,7 @@ TEST_F(Dense, SingleVectorApplyIsEquivalentToRef) TEST_F(Dense, SingleVectorAdvancedApplyIsEquivalentToRef) { - set_up_apply_data(1); + set_up_apply_data(257); mat->apply(alpha.get(), y.get(), beta.get(), expected.get()); dmat->apply(dalpha.get(), dy.get(), dbeta.get(), dresult.get()); From 20fe495724cfa1417b34ebd612fc73a1d9e6b540 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sun, 29 Oct 2023 15:03:45 +0100 Subject: [PATCH 06/28] minor dpcpp fixes --- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 
34 +++++++++------------ 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index a32f6f39da8..38d93d7213f 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -77,9 +77,8 @@ __dpct_inline__ void initialize( item_ct1); } } else { - // single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, - // item_ct1); single_rhs_compute_norm2(num_rows, b_global_entry, - // rhs_norm, item_ct1); + single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); + single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1); } item_ct1.barrier(sycl::access::fence_space::local_space); @@ -131,9 +130,9 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - // single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, - // alpha, item_ct1); - // alpha = rho_new / alpha; + single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, + alpha, item_ct1); + alpha = rho_new / alpha; } } @@ -175,13 +174,11 @@ __dpct_inline__ void compute_omega(const int num_rows, if (tid == 0) omega /= temp; item_ct1.barrier(sycl::access::fence_space::local_space); } else { - // single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, - // omega, - // item_ct1); - // single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, - // temp, - // item_ct1); - // omega /= temp; + single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, omega, + item_ct1); + single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, temp, + item_ct1); + omega /= temp; } } @@ -375,8 +372,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - // single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], - // item_ct1); + single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], + item_ct1); } // beta = (rho_new / rho_old)*(alpha / omega) @@ -410,8 +407,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - // single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], - // item_ct1); + single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1); } if (stop.check_converged(norms_res_sh)) { @@ -448,8 +444,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - // single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], - // item_ct1); rho_old_sh[0] = rho_new_sh[0]; + single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1); + rho_old_sh[0] = rho_new_sh[0]; } } From 2a41fd7a9a1a9342984510b28c0ccabe492c4d6a Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sun, 29 Oct 2023 16:24:03 +0100 Subject: [PATCH 07/28] Review updates Co-authored-by: Marcel Koch --- .../solver/batch_bicgstab_kernels.hpp.inc | 8 +-- core/solver/batch_bicgstab_kernels.hpp | 9 +++ cuda/base/exception.cuh | 11 ++-- cuda/solver/batch_bicgstab_kernels.cu | 62 +++++++++---------- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 34 +++++----- hip/base/exception.hip.hpp | 10 +-- hip/solver/batch_bicgstab_kernels.hip.cpp | 8 +-- 7 files changed, 76 insertions(+), 66 deletions(-) diff --git a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc 
b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc index a4a57d99f01..0f666f205e8 100644 --- a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc +++ b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc @@ -317,11 +317,11 @@ __global__ void apply_kernel( // alpha = rho_new / < r_hat , v> compute_alpha(subgroup, num_rows, rho_new_sh[0], r_hat_sh, v_sh, - alpha_sh[0] /*, converged*/); + alpha_sh[0]); __syncthreads(); // s = r - alpha*v - update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh /*, converged*/); + update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh); __syncthreads(); // an estimate of residual norms @@ -348,13 +348,13 @@ __global__ void apply_kernel( // omega = / compute_omega(subgroup, num_rows, t_sh, s_sh, temp_sh[0], - omega_sh[0] /*, converged*/); + omega_sh[0]); __syncthreads(); // x = x + alpha*p_hat + omega *s_hat // r = s - omega * t update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], - omega_sh[0], s_sh, t_sh, x_sh, r_sh /*, converged*/); + omega_sh[0], s_sh, t_sh, x_sh, r_sh); __syncthreads(); if (threadIdx.x / config::warp_size == 0) { diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp index ccde3aa6826..cd16be76d63 100644 --- a/core/solver/batch_bicgstab_kernels.hpp +++ b/core/solver/batch_bicgstab_kernels.hpp @@ -166,24 +166,33 @@ storage_config compute_shared_storage(const int shared_mem_per_blk, const int prec_storage = Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType); int rem_shared = shared_mem_per_blk; + // Set default values. All vecs are in global. storage_config sconf{false, 0, num_main_vecs, 0, num_rows}; + // If available shared mem, is zero, set all vecs to global. if (rem_shared <= 0) { set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; } + // Compute the number of vecs that can be stored in shared memory and assign + // the rest to global memory. const int initial_vecs_available = rem_shared / vec_size; const int num_vecs_shared = min(initial_vecs_available, num_main_vecs); sconf.n_shared += num_vecs_shared; sconf.n_global -= num_vecs_shared; + // Set the storage configuration with preconditioner workspace in global if + // there are any vectors in global memory. if (sconf.n_global > 0) { set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; } rem_shared -= num_vecs_shared * vec_size; + // If more shared memory space is available and preconditioner workspace is + // needed, enable preconditioner workspace to use shared memory. if (rem_shared >= prec_storage && prec_storage > 0) { sconf.prec_shared = true; rem_shared -= prec_storage; } + // Set the global storage config and align to 32 bytes. 
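+    // Worked example (illustrative numbers, assuming double precision
+    // values and a single right-hand side): with num_rows = 256, each main
+    // vector occupies 256 * 8 = 2048 bytes, so a 16 KiB shared memory
+    // budget keeps 8 of the 9 main vectors in shared memory, while the
+    // ninth vector and the preconditioner workspace fall back to global
+    // memory.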
set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; } diff --git a/cuda/base/exception.cuh b/cuda/base/exception.cuh index 51dfb63bf72..ccf74ebdb7b 100644 --- a/cuda/base/exception.cuh +++ b/cuda/base/exception.cuh @@ -41,10 +41,12 @@ namespace gko { #define GKO_CUDA_LAST_IF_ERROR_THROW \ - cudaError_t err = cudaGetLastError(); \ - if (err != cudaSuccess) { \ - printf(" Kernel error: %s\n", cudaGetErrorString(err)); \ - throw gko::CudaError(__FILE__, __LINE__, __func__, err); \ + { \ + cudaError_t err = cudaGetLastError(); \ + if (err != cudaSuccess) { \ + printf(" Kernel error: %s\n", cudaGetErrorString(err)); \ + throw gko::CudaError(__FILE__, __LINE__, __func__, err); \ + } \ } \ static_assert(true, \ "This assert is used to counter the false positive extra " \ @@ -53,4 +55,5 @@ namespace gko { } // namespace gko + #endif // GKO_CUDA_BASE_EXCEPTION_CUH_ diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 07e16535631..d6c3c3ba0fc 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -88,14 +88,11 @@ template exec, const int num_rows) { - int nwarps = num_rows / 4; - if (nwarps < 2) { - nwarps = 2; - } - const int min_block_size = 2 * config::warp_size; + int num_warps = std::max(num_rows / 4, 2); + constexpr int warp_sz = static_cast(config::warp_size); + const int min_block_size = 2 * warp_sz; const int device_max_threads = - ((std::max(num_rows, min_block_size)) / config::warp_size) * - config::warp_size; + ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz; cudaFuncAttributes funcattr; cudaFuncGetAttributes(&funcattr, apply_kernel exec, const int max_threads_regs = ((max_regs_blk / static_cast((static_cast(num_regs_used)))) / - config::warp_size) * - config::warp_size; + warp_sz) * + warp_sz; int max_threads = std::min(max_threads_regs, device_max_threads); max_threads = max_threads <= 1024 ? 
max_threads : 1024; - return std::min(nwarps * static_cast(config::warp_size), max_threads); + return std::min(num_warps * warp_sz, max_threads); } template -int get_max_dynamic_shared_memory(std::shared_ptr exec, - const size_type required_cache_storage) +int get_max_dynamic_shared_memory(std::shared_ptr exec) { int shmem_per_sm = 0; cudaDeviceGetAttribute(&shmem_per_sm, @@ -147,7 +143,7 @@ public: KernelCaller(std::shared_ptr exec, const settings> settings) - : exec_{exec}, settings_{settings} + : exec_{std::move(exec)}, settings_{settings} {} template ; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; - const int shared_gap = - ((mat.num_rows + align_multiple - 1) / align_multiple) * - align_multiple; + const int padded_num_rows = + ceildiv(mat.num_rows, align_multiple) * align_multiple; gko::kernels::cuda::configure_shared_memory_banks(); const int shmem_per_blk = get_max_dynamic_shared_memory(exec_, - 0); + BatchMatrixType, value_type>(exec_); const int block_size = get_num_threads_per_block( exec_, mat.num_rows); - assert(block_size >= 2 * config::warp_size); + GKO_ASSERT(block_size >= 2 * config::warp_size); const size_t prec_size = - PrecType::dynamic_work_size(shared_gap, + PrecType::dynamic_work_size(padded_num_rows, mat.get_single_item_num_nnz()) * sizeof(value_type); const auto sconf = gko::kernels::batch_bicgstab::compute_shared_storage( - shmem_per_blk, shared_gap, mat.get_single_item_num_nnz(), + shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * shared_gap * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(value_type) + (sconf.prec_shared ? prec_size : 0); auto workspace = gko::array( exec_, @@ -213,60 +207,60 @@ public: value_type* const workspace_data = workspace.get_data(); // Template parameters launch_apply_kernel if (sconf.prec_shared) - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); else { switch (sconf.n_shared) { case 0: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 1: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 2: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 3: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 4: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 5: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 6: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 7: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 8: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 9: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, 
shared_size); break; diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 61c888b357b..6c702ef65df 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -79,7 +79,7 @@ using settings = gko::kernels::batch_bicgstab::settings; __dpct_inline__ int get_group_size(int value, int simd_len = 32) { - int num_sg = (value + simd_len - 1) / simd_len; + int num_sg = ceildiv(value, simd_len); return (num_sg * simd_len); } @@ -89,7 +89,7 @@ class KernelCaller { public: KernelCaller(std::shared_ptr exec, const settings> settings) - : exec_{exec}, settings_{settings} + : exec_{std::move(exec)}, settings_{settings} {} template get_queue()->get_device(); auto group_size = device.get_info(); - if (group_size > num_rows) group_size = get_group_size(num_rows); - - size_type shmem_per_blk = + if (group_size > num_rows) { + group_size = get_group_size(num_rows); + }; + + // reserve 5 for intermediate rho-s, norms, + // alpha, omega, temp and for reduce_over_group + // If the value available is negative, then set it to 0 + size_type shmem_per_blk = std::max( device.get_info() - - (group_size + 5) * sizeof(ValueType) - - 2 * sizeof( - real_type); // reserve 5 for intermediate rho-s, norms, - // alpha, omega, temp and for reduce_over_group - if (shmem_per_blk < 0) shmem_per_blk = 0; - const int shared_gap = num_rows; + (group_size + 5) * sizeof(ValueType) - 2 * sizeof(real_type), + static_cast(0)); + const int padded_num_rows = num_rows; const size_type prec_size = PrecType::dynamic_work_size( - shared_gap, mat.get_single_item_num_nnz()); + padded_num_rows, mat.get_single_item_num_nnz()); const auto sconf = gko::kernels::batch_bicgstab::compute_shared_storage( - shmem_per_blk, shared_gap, mat.get_single_item_num_nnz(), + shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); - const size_t shared_size = - sconf.n_shared * shared_gap + (sconf.prec_shared ? prec_size : 0); + const size_t shared_size = sconf.n_shared * padded_num_rows + + (sconf.prec_shared ? 
prec_size : 0); auto workspace = gko::array( exec_, sconf.gmem_stride_bytes * num_batch_items / sizeof(ValueType)); - assert(sconf.gmem_stride_bytes % sizeof(ValueType) == 0); + GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(ValueType) == 0); ValueType* const workspace_data = workspace.get_data(); int n_shared_total = sconf.n_shared + int(sconf.prec_shared); diff --git a/hip/base/exception.hip.hpp b/hip/base/exception.hip.hpp index 7c3b3b2e12e..366f95bffbb 100644 --- a/hip/base/exception.hip.hpp +++ b/hip/base/exception.hip.hpp @@ -41,10 +41,12 @@ namespace gko { #define GKO_HIP_LAST_IF_ERROR_THROW \ - hipError_t err = hipGetLastError(); \ - if (err != hipSuccess) { \ - printf(" Hip kernel error: %s\n", hipGetErrorString(err)); \ - throw gko::HipError(__FILE__, __LINE__, __func__, err); \ + { \ + hipError_t err = hipGetLastError(); \ + if (err != hipSuccess) { \ + printf(" Hip kernel error: %s\n", hipGetErrorString(err)); \ + throw gko::HipError(__FILE__, __LINE__, __func__, err); \ + } \ } \ static_assert(true, \ "This assert is used to counter the false positive extra " \ diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index b9fe8b0c9c3..077b9b5da93 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -147,7 +147,7 @@ class KernelCaller { using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; - const int shared_gap = + const int padded_num_rows = ((mat.num_rows + align_multiple - 1) / align_multiple) * align_multiple; const int shmem_per_blk = exec_->get_max_shared_memory_per_block(); @@ -156,16 +156,16 @@ class KernelCaller { assert(block_size >= 2 * config::warp_size); const size_t prec_size = - PrecType::dynamic_work_size(shared_gap, + PrecType::dynamic_work_size(padded_num_rows, mat.get_single_item_num_nnz()) * sizeof(value_type); const auto sconf = gko::kernels::batch_bicgstab::compute_shared_storage( - shmem_per_blk, shared_gap, mat.get_single_item_num_nnz(), + shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * shared_gap * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(value_type) + (sconf.prec_shared ? prec_size : 0); auto workspace = gko::array( exec_, From 5c0f4f4a95bf6fa4b5f90caa091f24234b4fffb0 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sun, 29 Oct 2023 20:40:48 +0100 Subject: [PATCH 08/28] Fix sycl group and subgroup sizes --- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 89 ++++++++++++-------- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 2 +- 2 files changed, 53 insertions(+), 38 deletions(-) diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 0c18fd80806..b4dbf1ced31 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -33,6 +33,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "core/base/batch_multi_vector_kernels.hpp" +#include + + #include @@ -77,10 +80,15 @@ void scale(std::shared_ptr exec, const auto alpha_ub = get_batch_struct(alpha); const auto x_ub = get_batch_struct(x); + const int num_rows = x->get_common_size()[0]; + constexpr int max_subgroup_size = config::warp_size; const auto num_batches = x_ub.num_batch_items; auto device = exec->get_queue()->get_device(); - auto group_size = + long max_group_size = device.get_info(); + int group_size = + std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); @@ -125,13 +133,16 @@ void add_scaled(std::shared_ptr exec, const batch::MultiVector* const x, batch::MultiVector* const y) { - const size_type num_rows = x->get_common_size()[0]; - const size_type num_cols = x->get_common_size()[1]; - + constexpr int max_subgroup_size = config::warp_size; + const int num_rows = x->get_common_size()[0]; + const int num_cols = x->get_common_size()[1]; const auto num_batches = x->get_num_batch_items(); auto device = exec->get_queue()->get_device(); - auto group_size = + long max_group_size = device.get_info(); + int group_size = + std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); @@ -183,22 +194,25 @@ void compute_dot(std::shared_ptr exec, const auto y_ub = get_batch_struct(y); const auto res_ub = get_batch_struct(result); + constexpr int max_subgroup_size = config::warp_size; const auto num_batches = x_ub.num_batch_items; - const auto num_rows = x_ub.num_rows; + const int num_rows = x_ub.num_rows; auto device = exec->get_queue()->get_device(); + long max_group_size = + device.get_info(); + int group_size = + std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); + const dim3 block(group_size); + const dim3 grid(num_batches); if (x->get_common_size()[1] == 1) { - int group_size = ((num_rows + 32 - 1) / 32) * 32; - - const dim3 block(group_size); - const dim3 grid(num_batches); - exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( sycl_nd_range(grid, block), [= ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { + max_subgroup_size)]] { auto group = item_ct1.get_group(); auto group_id = group.get_group_linear_id(); const auto x_b = batch::extract_batch_item(x_ub, group_id); @@ -211,18 +225,12 @@ void compute_dot(std::shared_ptr exec, }); }); } else { - auto group_size = - device.get_info(); - - const dim3 block(group_size); - const dim3 grid(num_batches); - // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( sycl_nd_range(grid, block), [= ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { + max_subgroup_size)]] { auto group = item_ct1.get_group(); auto group_id = group.get_group_linear_id(); const auto x_b = batch::extract_batch_item(x_ub, group_id); @@ -251,10 +259,15 @@ void compute_conj_dot(std::shared_ptr exec, const auto y_ub = get_batch_struct(y); const auto res_ub = get_batch_struct(result); + constexpr int max_subgroup_size = config::warp_size; + const int num_rows = x->get_common_size()[0]; const auto num_batches = x_ub.num_batch_items; auto device = exec->get_queue()->get_device(); - auto group_size = + long max_group_size = device.get_info(); + int group_size = + std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + 
max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); @@ -263,7 +276,7 @@ void compute_conj_dot(std::shared_ptr exec, cgh.parallel_for( sycl_nd_range(grid, block), [= ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { + max_subgroup_size)]] { auto group = item_ct1.get_group(); auto group_id = group.get_group_linear_id(); const auto x_b = batch::extract_batch_item(x_ub, group_id); @@ -289,20 +302,24 @@ void compute_norm2(std::shared_ptr exec, const auto res_ub = get_batch_struct(result); const auto num_batches = x_ub.num_batch_items; - const auto num_rows = x->get_common_size()[0]; + const int num_rows = x->get_common_size()[0]; auto device = exec->get_queue()->get_device(); - if (x->get_common_size()[1] == 1) { - int group_size = ((num_rows + 32 - 1) / 32) * 32; - - const dim3 block(group_size); - const dim3 grid(num_batches); + constexpr int max_subgroup_size = config::warp_size; + long max_group_size = + device.get_info(); + int group_size = + std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); + const dim3 block(group_size); + const dim3 grid(num_batches); + if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( sycl_nd_range(grid, block), [= ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { + max_subgroup_size)]] { auto group = item_ct1.get_group(); auto group_id = group.get_group_linear_id(); const auto x_b = batch::extract_batch_item(x_ub, group_id); @@ -312,19 +329,12 @@ void compute_norm2(std::shared_ptr exec, res_b.values[0], item_ct1); }); }); - } else { - auto group_size = - device.get_info(); - - const dim3 block(group_size); - const dim3 grid(num_batches); - exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( sycl_nd_range(grid, block), [= ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { + max_subgroup_size)]] { auto group = item_ct1.get_group(); auto group_id = group.get_group_linear_id(); const auto x_b = batch::extract_batch_item(x_ub, group_id); @@ -349,9 +359,14 @@ void copy(std::shared_ptr exec, const auto result_ub = get_batch_struct(result); const auto num_batches = x_ub.num_batch_items; + const int num_rows = x->get_common_size()[0]; auto device = exec->get_queue()->get_device(); - auto group_size = + constexpr int max_subgroup_size = config::warp_size; + long max_group_size = device.get_info(); + int group_size = + std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 6c702ef65df..c40d8564d09 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -80,7 +80,7 @@ using settings = gko::kernels::batch_bicgstab::settings; __dpct_inline__ int get_group_size(int value, int simd_len = 32) { int num_sg = ceildiv(value, simd_len); - return (num_sg * simd_len); + return num_sg * simd_len; } From a3fe9bb4b8460e655beaca0a746e5e216ba26c24 Mon Sep 17 00:00:00 2001 From: ginkgo-bot Date: Sun, 29 Oct 2023 19:43:51 +0000 Subject: [PATCH 09/28] Format files Co-authored-by: Pratik Nayak --- .../base/batch_multi_vector_kernels.hpp.inc | 29 ++-- cuda/solver/batch_bicgstab_kernels.cu | 3 +- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 125 +++++++++--------- dpcpp/matrix/batch_dense_kernels.dp.cpp | 56 ++++---- 
dpcpp/matrix/batch_ell_kernels.dp.cpp | 56 ++++---- dpcpp/preconditioner/batch_identity.hpp.inc | 8 +- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 7 +- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 1 - hip/solver/batch_bicgstab_kernels.hip.cpp | 3 +- 9 files changed, 145 insertions(+), 143 deletions(-) diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc index 779e2ab0e68..72d58ecf5b3 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc @@ -47,15 +47,10 @@ __device__ __forceinline__ void scale( } template -__global__ __launch_bounds__( - default_block_size, - sm_oversubscription) void scale_kernel(const gko::batch::multi_vector:: - uniform_batch - alpha, - const gko::batch::multi_vector:: - uniform_batch - x, - Mapping map) +__global__ +__launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel( + const gko::batch::multi_vector::uniform_batch alpha, + const gko::batch::multi_vector::uniform_batch x, Mapping map) { for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items; batch_id += gridDim.x) { @@ -176,11 +171,11 @@ __device__ __forceinline__ void compute_gen_dot_product( template __global__ - __launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel( - const gko::batch::multi_vector::uniform_batch x, - const gko::batch::multi_vector::uniform_batch y, - const gko::batch::multi_vector::uniform_batch result, - Mapping map) +__launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel( + const gko::batch::multi_vector::uniform_batch x, + const gko::batch::multi_vector::uniform_batch y, + const gko::batch::multi_vector::uniform_batch result, + Mapping map) { for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items; batch_id += gridDim.x) { @@ -319,9 +314,9 @@ __device__ __forceinline__ void copy( template __global__ - __launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel( - const gko::batch::multi_vector::uniform_batch src, - const gko::batch::multi_vector::uniform_batch dst) +__launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel( + const gko::batch::multi_vector::uniform_batch src, + const gko::batch::multi_vector::uniform_batch dst) { for (size_type batch_id = blockIdx.x; batch_id < src.num_batch_items; batch_id += gridDim.x) { diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index d6c3c3ba0fc..9ecb27aecf2 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -75,9 +75,8 @@ constexpr int sm_oversubscription = 4; namespace batch_bicgstab { -#include "common/cuda_hip/components/uninitialized_array.hpp.inc" - #include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" +#include "common/cuda_hip/components/uninitialized_array.hpp.inc" #include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" #include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc" diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index b4dbf1ced31..3068b654b75 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -210,37 +210,41 @@ void compute_dot(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) 
{ cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_dot_sg(x_b.num_rows, x_b.values, - y_b.values, res_b.values[0], - item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto y_b = + batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_dot_sg(x_b.num_rows, x_b.values, + y_b.values, res_b.values[0], + item_ct1); + }); }); } else { // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return val; }); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto y_b = + batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return val; }); + }); }); } } @@ -274,18 +278,19 @@ void compute_conj_dot(std::shared_ptr exec, exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return conj(val); }); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return conj(val); }); + }); }); } @@ -317,31 +322,33 @@ void compute_norm2(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) 
[[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, - res_b.values[0], item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, + res_b.values[0], item_ct1); + }); }); } else { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_norm2_kernel(x_b, res_b, item_ct1); + }); }); } } diff --git a/dpcpp/matrix/batch_dense_kernels.dp.cpp b/dpcpp/matrix/batch_dense_kernels.dp.cpp index a80ef047e8d..d1320e79968 100644 --- a/dpcpp/matrix/batch_dense_kernels.dp.cpp +++ b/dpcpp/matrix/batch_dense_kernels.dp.cpp @@ -100,17 +100,18 @@ void simple_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b.values, x_b.values, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + simple_apply_kernel(mat_b, b_b.values, x_b.values, + item_ct1); + }); }); } @@ -147,22 +148,23 @@ void advanced_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = 
batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto alpha_b = - batch::extract_batch_item(alpha_ub, group_id); - const auto beta_b = - batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, - beta_b.values[0], x_b.values, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto beta_b = + batch::extract_batch_item(beta_ub, group_id); + advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, + beta_b.values[0], x_b.values, + item_ct1); + }); }); } diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp index 1ebd41a7e24..f565f69f270 100644 --- a/dpcpp/matrix/batch_ell_kernels.dp.cpp +++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp @@ -97,17 +97,18 @@ void simple_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b.values, x_b.values, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + simple_apply_kernel(mat_b, b_b.values, x_b.values, + item_ct1); + }); }); } @@ -145,22 +146,23 @@ void advanced_apply(std::shared_ptr exec, // Launch a kernel that has nbatches blocks, each block has max group size exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto alpha_b = - batch::extract_batch_item(alpha_ub, group_id); - const auto beta_b = - batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, - beta_b.values[0], x_b.values, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = 
batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto beta_b = + batch::extract_batch_item(beta_ub, group_id); + advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, + beta_b.values[0], x_b.values, + item_ct1); + }); }); } diff --git a/dpcpp/preconditioner/batch_identity.hpp.inc b/dpcpp/preconditioner/batch_identity.hpp.inc index 53e2f70a7d9..404d987a3f4 100644 --- a/dpcpp/preconditioner/batch_identity.hpp.inc +++ b/dpcpp/preconditioner/batch_identity.hpp.inc @@ -42,10 +42,10 @@ public: static int dynamic_work_size(int, int) { return 0; } - void generate(size_type batch_id, - const gko::batch::matrix::ell::batch_item&, - ValueType* const, sycl::nd_item<3> item_ct1) + void generate( + size_type batch_id, + const gko::batch::matrix::ell::batch_item&, + ValueType* const, sycl::nd_item<3> item_ct1) {} void generate(size_type batch_id, diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index c40d8564d09..33749e91ae4 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -117,10 +117,9 @@ class KernelCaller { slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - simd_len)]] [ - [intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + simd_len)]] [[intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 38d93d7213f..67057f80e53 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -30,7 +30,6 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*************************************************************/ - template __dpct_inline__ void initialize( diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 077b9b5da93..f769a4fe6f5 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -74,9 +74,8 @@ constexpr int sm_oversubscription = 4; namespace batch_bicgstab { -#include "common/cuda_hip/components/uninitialized_array.hpp.inc" - #include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" +#include "common/cuda_hip/components/uninitialized_array.hpp.inc" #include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" #include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc" From 4072b5099311ee8c3c0bae1e373d194e38fe2bfc Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Mon, 30 Oct 2023 12:13:02 +0100 Subject: [PATCH 10/28] Review updates Co-authored-by: Yu-Hsiang Tsai Co-authored-by: Marcel Koch --- .../base/batch_multi_vector_kernels.hpp.inc | 2 +- .../preconditioner/batch_identity.hpp.inc | 13 +- .../solver/batch_bicgstab_kernels.hpp.inc | 26 ++-- core/solver/batch_bicgstab_kernels.hpp | 18 +-- core/test/utils/batch_helpers.hpp | 18 +-- cuda/matrix/batch_struct.hpp | 3 +- cuda/solver/batch_bicgstab_kernels.cu | 17 +-- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 137 +++++++++--------- dpcpp/base/batch_multi_vector_kernels.hpp.inc | 13 +- dpcpp/matrix/batch_struct.hpp | 2 +- dpcpp/preconditioner/batch_identity.hpp.inc | 12 +- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 45 +++--- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 112 +++++++------- hip/matrix/batch_struct.hip.hpp | 3 +- hip/solver/batch_bicgstab_kernels.hip.cpp | 57 ++++---- reference/matrix/batch_struct.hpp | 2 +- .../test/solver/batch_bicgstab_kernels.cpp | 25 ++-- test/solver/batch_bicgstab_kernels.cpp | 24 +-- 18 files changed, 254 insertions(+), 275 deletions(-) diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc index 72d58ecf5b3..1e0cb3bbcff 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc @@ -104,7 +104,7 @@ __global__ __launch_bounds__( template -__device__ __forceinline__ void single_rhs_compute_dot(Group subgroup, +__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup, const int num_rows, const ValueType* x, const ValueType* y, diff --git a/common/cuda_hip/preconditioner/batch_identity.hpp.inc b/common/cuda_hip/preconditioner/batch_identity.hpp.inc index 1b1fb7b5482..923ed4ce946 100644 --- a/common/cuda_hip/preconditioner/batch_identity.hpp.inc +++ b/common/cuda_hip/preconditioner/batch_identity.hpp.inc @@ -45,16 +45,9 @@ public: return 0; } - __device__ __forceinline__ void generate( - size_type, - const gko::batch::matrix::ell::batch_item&, - ValueType*) - {} - - __device__ __forceinline__ void generate( - size_type, - const gko::batch::matrix::dense::batch_item&, - ValueType*) + template + __device__ __forceinline__ void generate(size_type, const batch_item_type&, + ValueType*) {} __device__ __forceinline__ void apply(const int num_rows, diff --git a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc index 0f666f205e8..faee2e069a7 100644 --- a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc +++ 
b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc @@ -38,7 +38,8 @@ __device__ __forceinline__ void initialize( const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega, ValueType& alpha, ValueType* const x_shared_entry, ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry, - ValueType* const p_shared_entry, ValueType* const v_shared_entry, + ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry, + ValueType* const v_shared_entry, typename gko::remove_complex& rhs_norm, typename gko::remove_complex& res_norm) { @@ -70,6 +71,7 @@ __device__ __forceinline__ void initialize( for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) { r_hat_shared_entry[iz] = r_shared_entry[iz]; p_shared_entry[iz] = zero(); + p_hat_shared_entry[iz] = zero(); v_shared_entry[iz] = zero(); } } @@ -82,8 +84,8 @@ __device__ __forceinline__ void update_p( const ValueType* const r_shared_entry, const ValueType* const v_shared_entry, ValueType* const p_shared_entry) { + const ValueType beta = (rho_new / rho_old) * (alpha / omega); for (int r = threadIdx.x; r < num_rows; r += blockDim.x) { - const ValueType beta = (rho_new / rho_old) * (alpha / omega); p_shared_entry[r] = r_shared_entry[r] + beta * (p_shared_entry[r] - omega * v_shared_entry[r]); @@ -97,8 +99,8 @@ __device__ __forceinline__ void compute_alpha( const ValueType* const v_shared_entry, ValueType& alpha) { if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_dot(subgroup, num_rows, r_hat_shared_entry, - v_shared_entry, alpha); + single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry, + v_shared_entry, alpha); } __syncthreads(); if (threadIdx.x == 0) { @@ -126,11 +128,11 @@ __device__ __forceinline__ void compute_omega( const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega) { if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_dot(subgroup, num_rows, t_shared_entry, - s_shared_entry, omega); + single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, + s_shared_entry, omega); } else if (threadIdx.x / config::warp_size == 1) { - single_rhs_compute_dot(subgroup, num_rows, t_shared_entry, - t_shared_entry, temp); + single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, + t_shared_entry, temp); } __syncthreads(); @@ -278,10 +280,12 @@ __global__ void apply_kernel( // compute residual norms // r_hat = r // p = 0 + // p_hat = 0 // v = 0 initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr, rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, - r_hat_sh, p_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0]); + r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], + norms_res_sh[0]); __syncthreads(); // stopping criterion object @@ -296,8 +300,8 @@ __global__ void apply_kernel( // rho_new = < r_hat , r > = (r_hat)' * (r) if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_dot(subgroup, num_rows, r_hat_sh, r_sh, - rho_new_sh[0]); + single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh, + rho_new_sh[0]); } __syncthreads(); diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp index cd16be76d63..6f5de2e770c 100644 --- a/core/solver/batch_bicgstab_kernels.hpp +++ b/core/solver/batch_bicgstab_kernels.hpp @@ -117,8 +117,7 @@ void set_gmem_stride_bytes(storage_config& sconf, } // align global memory chunks sconf.gmem_stride_bytes = - gmem_stride > 0 ? ((gmem_stride - 1) / align_bytes + 1) * align_bytes - : 0; + gmem_stride > 0 ? 
ceildiv(gmem_stride, align_bytes) * align_bytes : 0; } @@ -145,8 +144,8 @@ void set_gmem_stride_bytes(storage_config& sconf, * - rhs_norms * - res_norms * - * @param shared_mem_per_blk The amount of shared memory per block to use for - * keeping intermediate vectors. In case keeping the matrix in L1 cache etc. + * @param available_shared_mem The amount of shared memory per block to use + * for keeping intermediate vectors. In case keeping the matrix in L1 cache etc. * should be prioritized, the cache configuration must be updated separately * and the needed space should be subtracted before passing to this * function. @@ -156,7 +155,7 @@ void set_gmem_stride_bytes(storage_config& sconf, * @return A struct containing allocation information specific to Bicgstab. */ template -storage_config compute_shared_storage(const int shared_mem_per_blk, +storage_config compute_shared_storage(const int available_shared_mem, const int num_rows, const int num_nz, const int num_rhs) { @@ -165,10 +164,11 @@ storage_config compute_shared_storage(const int shared_mem_per_blk, const int num_main_vecs = 9; const int prec_storage = Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType); - int rem_shared = shared_mem_per_blk; - // Set default values. All vecs are in global. + int rem_shared = available_shared_mem; + // Set default values. Initially all vecs are in global memory. + // {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len} storage_config sconf{false, 0, num_main_vecs, 0, num_rows}; - // If available shared mem, is zero, set all vecs to global. + // If available shared mem is zero, set all vecs to global. if (rem_shared <= 0) { set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; @@ -179,13 +179,13 @@ storage_config compute_shared_storage(const int shared_mem_per_blk, const int num_vecs_shared = min(initial_vecs_available, num_main_vecs); sconf.n_shared += num_vecs_shared; sconf.n_global -= num_vecs_shared; + rem_shared -= num_vecs_shared * vec_size; // Set the storage configuration with preconditioner workspace in global if // there are any vectors in global memory. if (sconf.n_global > 0) { set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; } - rem_shared -= num_vecs_shared * vec_size; // If more shared memory space is available and preconditioner workspace is // needed, enable preconditioner workspace to use shared memory. 
if (rem_shared >= prec_storage && prec_storage > 0) { diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp index 0a6702ff42f..43da4cd9d54 100644 --- a/core/test/utils/batch_helpers.hpp +++ b/core/test/utils/batch_helpers.hpp @@ -224,7 +224,7 @@ struct LinearSystem { std::shared_ptr matrix; std::shared_ptr rhs; - std::shared_ptr rhs_norm; + std::shared_ptr host_rhs_norm; std::shared_ptr exact_sol; }; @@ -250,8 +250,8 @@ LinearSystem generate_batch_linear_system( // A * x_{exact} = b sys.matrix->apply(sys.exact_sol, sys.rhs); const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); - sys.rhs_norm = real_vec::create(exec->get_master(), norm_dim); - sys.rhs->compute_norm2(sys.rhs_norm.get()); + sys.host_rhs_norm = real_vec::create(exec->get_master(), norm_dim); + sys.rhs->compute_norm2(sys.host_rhs_norm.get()); return sys; } @@ -273,13 +273,13 @@ compute_residual_norms( const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); auto residual_vec = b->clone(); - auto res_norms = real_vec::create(exec->get_master(), norm_dim); + auto res_norm = real_vec::create(exec->get_master(), norm_dim); auto alpha = gko::batch::initialize(num_batch_items, {-1.0}, exec); auto beta = gko::batch::initialize(num_batch_items, {1.0}, exec); mtx->apply(alpha, x, beta, residual_vec); - residual_vec->compute_norm2(res_norms); - return res_norms; + residual_vec->compute_norm2(res_norm); + return res_norm; } @@ -289,7 +289,7 @@ struct Result { using real_vec = batch::MultiVector>; std::shared_ptr x; - std::shared_ptr res_norm; + std::shared_ptr host_res_norm; }; @@ -323,7 +323,7 @@ Result solve_linear_system( result.x->fill(zero()); solver->apply(sys.rhs, result.x); - result.res_norm = + result.host_res_norm = compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get()); return std::move(result); @@ -369,7 +369,7 @@ ResultWithLogData solve_linear_system( result.log_data->iter_counts = log_data->iter_counts; result.log_data->res_norms = log_data->res_norms; - result.res_norm = + result.host_res_norm = compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get()); return std::move(result); diff --git a/cuda/matrix/batch_struct.hpp b/cuda/matrix/batch_struct.hpp index 4a2a1835961..55a30c043e3 100644 --- a/cuda/matrix/batch_struct.hpp +++ b/cuda/matrix/batch_struct.hpp @@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch, IndexType> +inline batch::matrix::ell::uniform_batch, + const IndexType> get_batch_struct(const batch::matrix::Ell* const op) { return {as_cuda_type(op->get_const_values()), diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 9ecb27aecf2..1c26a1754ba 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -101,13 +101,10 @@ int get_num_threads_per_block(std::shared_ptr exec, cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock, exec->get_device_id()); const int max_threads_regs = - ((max_regs_blk / - static_cast((static_cast(num_regs_used)))) / - warp_sz) * - warp_sz; + ((max_regs_blk / static_cast(num_regs_used)) / warp_sz) * warp_sz; int max_threads = std::min(max_threads_regs, device_max_threads); max_threads = max_threads <= 1024 ? 
max_threads : 1024; - return std::min(num_warps * warp_sz, max_threads); + return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size); } @@ -136,12 +133,12 @@ using settings = gko::kernels::batch_bicgstab::settings; template -class KernelCaller { +class kernel_caller { public: using value_type = CuValueType; - KernelCaller(std::shared_ptr exec, - const settings> settings) + kernel_caller(std::shared_ptr exec, + const settings> settings) : exec_{std::move(exec)}, settings_{settings} {} @@ -263,6 +260,8 @@ public: sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; + default: + GKO_NOT_IMPLEMENTED; } } @@ -286,7 +285,7 @@ void apply(std::shared_ptr exec, { using cu_value_type = cuda_type; auto dispatcher = batch::solver::create_dispatcher( - KernelCaller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); } diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 3068b654b75..c9809696889 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -87,7 +87,7 @@ void scale(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -141,7 +141,7 @@ void add_scaled(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -202,7 +202,7 @@ void compute_dot(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -210,41 +210,37 @@ void compute_dot(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto y_b = - batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_dot_sg(x_b.num_rows, x_b.values, - y_b.values, res_b.values[0], - item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_conj_dot_sg(x_b.num_rows, x_b.values, + y_b.values, res_b.values[0], + item_ct1); + }); }); } else { // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = 
item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto y_b = - batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return val; }); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return val; }); + }); }); } } @@ -270,7 +266,7 @@ void compute_conj_dot(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -278,19 +274,18 @@ void compute_conj_dot(std::shared_ptr exec, exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return conj(val); }); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return conj(val); }); + }); }); } @@ -314,7 +309,7 @@ void compute_norm2(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -322,33 +317,31 @@ void compute_norm2(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, - res_b.values[0], item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + 
single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, + res_b.values[0], item_ct1); + }); }); } else { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_norm2_kernel(x_b, res_b, item_ct1); + }); }); } } @@ -372,7 +365,7 @@ void copy(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index 4db1dc5e1d7..1fb5684871d 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -67,15 +67,12 @@ __dpct_inline__ void add_scaled_kernel( } -template -__dpct_inline__ void single_rhs_compute_dot( +template +__dpct_inline__ void single_rhs_compute_conj_dot( const int num_rows, const ValueType* const __restrict__ x, const ValueType* const __restrict__ y, ValueType& result, sycl::nd_item<3> item_ct1) { - // auto grp = - // group::tiled_partition(group::this_thread_block(item_ct1)); - // auto grp = group::this_thread_block(item_ct1); const auto group = item_ct1.get_group(); const auto group_size = item_ct1.get_local_range().size(); const auto tid = item_ct1.get_local_linear_id(); @@ -90,7 +87,7 @@ __dpct_inline__ void single_rhs_compute_dot( template -__dpct_inline__ void single_rhs_compute_dot_sg( +__dpct_inline__ void single_rhs_compute_conj_dot_sg( const int num_rows, const ValueType* const __restrict__ x, const ValueType* const __restrict__ y, ValueType& result, sycl::nd_item<3> item_ct1) @@ -183,8 +180,6 @@ __dpct_inline__ void single_rhs_compute_norm2( const int num_rows, const ValueType* const __restrict__ x, gko::remove_complex& result, sycl::nd_item<3> item_ct1) { - // auto grp = - // group::tiled_partition(group::this_thread_block(item_ct1)); const auto group = item_ct1.get_group(); const auto group_size = item_ct1.get_local_range().size(); const auto tid = item_ct1.get_local_linear_id(); @@ -197,8 +192,6 @@ __dpct_inline__ void single_rhs_compute_norm2( } val = sycl::reduce_over_group(group, val, sycl::plus<>()); - // val = ::gko::kernels::dpcpp::reduce( - // grp, val, [](real_type a, real_type b) { return a + b; }); result = sqrt(val); } diff --git a/dpcpp/matrix/batch_struct.hpp b/dpcpp/matrix/batch_struct.hpp index fe04407d82d..7f36378d8e1 100644 --- a/dpcpp/matrix/batch_struct.hpp +++ b/dpcpp/matrix/batch_struct.hpp @@ -91,7 +91,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. 
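 * The returned struct is a non-owning view of the matrix data, so it is
 * only valid for as long as the source Ell matrix is alive.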
*/ template -inline batch::matrix::ell::uniform_batch +inline batch::matrix::ell::uniform_batch get_batch_struct(const batch::matrix::Ell* const op) { return {op->get_const_values(), diff --git a/dpcpp/preconditioner/batch_identity.hpp.inc b/dpcpp/preconditioner/batch_identity.hpp.inc index 404d987a3f4..792886f845d 100644 --- a/dpcpp/preconditioner/batch_identity.hpp.inc +++ b/dpcpp/preconditioner/batch_identity.hpp.inc @@ -42,15 +42,9 @@ public: static int dynamic_work_size(int, int) { return 0; } - void generate( - size_type batch_id, - const gko::batch::matrix::ell::batch_item&, - ValueType* const, sycl::nd_item<3> item_ct1) - {} - - void generate(size_type batch_id, - const gko::batch::matrix::dense::batch_item&, - ValueType* const, sycl::nd_item<3> item_ct1) + template + void generate(size_type, const batch_item_type&, ValueType*, + sycl::nd_item<3> item_ct1) {} __dpct_inline__ void apply(const int num_rows, const ValueType* const r, diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 33749e91ae4..52f794cfc0e 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -60,9 +60,9 @@ namespace gko { namespace kernels { namespace dpcpp { /** - * @brief The batch Cg solver namespace. + * @brief The batch Bicgstab solver namespace. * - * @ingroup batch_cg + * @ingroup batch_bicgstab */ namespace batch_bicgstab { @@ -77,10 +77,10 @@ template using settings = gko::kernels::batch_bicgstab::settings; -__dpct_inline__ int get_group_size(int value, int simd_len = 32) +__dpct_inline__ int get_group_size(int value, int subgroup_size = 32) { - int num_sg = ceildiv(value, simd_len); - return num_sg * simd_len; + int num_sg = ceildiv(value, subgroup_size); + return num_sg * subgroup_size; } @@ -92,9 +92,9 @@ class KernelCaller { : exec_{std::move(exec)}, settings_{settings} {} - template + template __dpct_inline__ void launch_apply_kernel( const gko::kernels::batch_bicgstab::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType mat, @@ -111,15 +111,16 @@ class KernelCaller { auto max_iters = settings_.max_iterations; auto res_tol = settings_.residual_tol; - (exec_->get_queue())->submit([&](sycl::handler& cgh) { + exec_->get_queue()->submit([&](sycl::handler& cgh) { sycl::accessor slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - simd_len)]] [[intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [ + [intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); @@ -162,10 +163,13 @@ class KernelCaller { // reserve 5 for intermediate rho-s, norms, // alpha, omega, temp and for reduce_over_group // If the value available is negative, then set it to 0 - size_type shmem_per_blk = std::max( - device.get_info() - - (group_size + 5) * sizeof(ValueType) - 2 * sizeof(real_type), - static_cast(0)); + const int static_var_mem = + (group_size + 5) * sizeof(ValueType) - 2 * sizeof(real_type); + int shmem_per_blk = std::max( + static_cast( + device.get_info()) - + static_var_mem, + 0); const int padded_num_rows = num_rows; const size_type prec_size = PrecType::dynamic_work_size( padded_num_rows, mat.get_single_item_num_nnz()); @@ -185,16 +189,17 @@ class KernelCaller { int n_shared_total = 
sconf.n_shared + int(sconf.prec_shared); // template - // launch_apply_kernel - if (num_rows <= 32 && n_shared_total == 10) + // launch_apply_kernel + if (num_rows <= 32 && n_shared_total == 10) { launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); - else if (num_rows <= 256 && n_shared_total == 10) + } else if (num_rows <= 256 && n_shared_total == 10) { launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); - else { + } else { switch (n_shared_total) { case 0: launch_apply_kernel( diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 67057f80e53..0b6f4511f02 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -39,6 +39,7 @@ __dpct_inline__ void initialize( ValueType& alpha, ValueType* const x_shared_entry, ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry, ValueType* const p_shared_entry, ValueType* const v_shared_entry, + ValueType* const p_hat_shared_entry, typename gko::remove_complex& rhs_norm, typename gko::remove_complex& res_norm, sycl::nd_item<3> item_ct1) @@ -85,6 +86,7 @@ __dpct_inline__ void initialize( for (int iz = tid; iz < num_rows; iz += group_size) { r_hat_shared_entry[iz] = r_shared_entry[iz]; p_shared_entry[iz] = zero(); + p_hat_shared_entry[iz] = zero(); v_shared_entry[iz] = zero(); } } @@ -115,23 +117,24 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, const ValueType* const v_shared_entry, ValueType& alpha, sycl::nd_item<3> item_ct1) { + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); if constexpr (sg_kernel_all) { - auto sg = item_ct1.get_sub_group(); - const auto sg_id = sg.get_group_id(); - const auto tid = item_ct1.get_local_linear_id(); - if (sg_id == 0) { - single_rhs_compute_dot_sg(num_rows, r_hat_shared_entry, - v_shared_entry, alpha, item_ct1); + single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); } if (tid == 0) { alpha = rho_new / alpha; } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, - alpha, item_ct1); - alpha = rho_new / alpha; + single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); + if (tid == 0) { + alpha = rho_new / alpha; + } } } @@ -158,26 +161,30 @@ __dpct_inline__ void compute_omega(const int num_rows, ValueType& temp, ValueType& omega, sycl::nd_item<3> item_ct1) { + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); if constexpr (sg_kernel_all) { - auto sg = item_ct1.get_sub_group(); - const auto sg_id = sg.get_group_id(); - const auto tid = item_ct1.get_local_linear_id(); - - if (sg_id == 0) - single_rhs_compute_dot_sg(num_rows, t_shared_entry, s_shared_entry, - omega, item_ct1); - else if (sg_id == 1) - single_rhs_compute_dot_sg(num_rows, t_shared_entry, t_shared_entry, - temp, item_ct1); + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + s_shared_entry, omega, item_ct1); + } else if (sg_id == 1) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + t_shared_entry, temp, item_ct1); + } item_ct1.barrier(sycl::access::fence_space::local_space); - if (tid == 0) omega /= temp; + if (tid == 0) { + omega /= 
temp; + } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, omega, - item_ct1); - single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, temp, - item_ct1); - omega /= temp; + single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry, + omega, item_ct1); + single_rhs_compute_conj_dot(num_rows, t_shared_entry, t_shared_entry, + temp, item_ct1); + if (tid == 0) { + omega /= temp; + } } } @@ -244,33 +251,21 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, real_type* norms_rhs_sh; real_type* norms_res_sh; - if constexpr (sg_kernel_all) { - using tile_value_t = ValueType[5]; - tile_value_t& values = - *sycl::ext::oneapi::group_local_memory_for_overwrite( - group); - using tile_real_t = real_type[2]; - tile_real_t& reals = - *sycl::ext::oneapi::group_local_memory_for_overwrite( - group); - rho_old_sh = &values[0]; - rho_new_sh = &values[1]; - alpha_sh = &values[2]; - omega_sh = &values[3]; - temp_sh = &values[4]; - norms_rhs_sh = &reals[0]; - norms_res_sh = &reals[1]; - } else { - ValueType values[5]; - real_type reals[2]; - rho_old_sh = &values[0]; - rho_new_sh = &values[1]; - alpha_sh = &values[2]; - omega_sh = &values[3]; - temp_sh = &values[4]; - norms_rhs_sh = &reals[0]; - norms_res_sh = &reals[1]; - } + using tile_value_t = ValueType[5]; + tile_value_t& values = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + using tile_real_t = real_type[2]; + tile_real_t& reals = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + rho_old_sh = &values[0]; + rho_new_sh = &values[1]; + alpha_sh = &values[2]; + omega_sh = &values[3]; + temp_sh = &values[4]; + norms_rhs_sh = &reals[0]; + norms_res_sh = &reals[1]; const int gmem_offset = batch_id * sconf.gmem_stride_bytes / sizeof(ValueType); ValueType* p_hat_sh; @@ -346,11 +341,12 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // compute residual norms // r_hat = r // p = 0 + // p_hat = 0 // v = 0 initialize(num_rows, mat_global_entry, b_global_entry, x_global_entry, rho_old_sh[0], omega_sh[0], - alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, v_sh, - norms_rhs_sh[0], norms_res_sh[0], item_ct1); + alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh, + v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1); item_ct1.barrier(sycl::access::fence_space::local_space); // stopping criterion object @@ -366,13 +362,13 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // rho_new = < r_hat , r > = (r_hat)' * (r) if constexpr (sg_kernel_all) { if (sg_id == 0) { - single_rhs_compute_dot_sg(num_rows, r_hat_sh, r_sh, - rho_new_sh[0], item_ct1); + single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, + rho_new_sh[0], item_ct1); } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], - item_ct1); + single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], + item_ct1); } // beta = (rho_new / rho_old)*(alpha / omega) diff --git a/hip/matrix/batch_struct.hip.hpp b/hip/matrix/batch_struct.hip.hpp index e35f13f1249..ba75b1b634e 100644 --- a/hip/matrix/batch_struct.hip.hpp +++ b/hip/matrix/batch_struct.hip.hpp @@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. 
*/ template -inline batch::matrix::ell::uniform_batch, IndexType> +inline batch::matrix::ell::uniform_batch, + const IndexType> get_batch_struct(const batch::matrix::Ell* const op) { return {as_hip_type(op->get_const_values()), diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index f769a4fe6f5..a7c3667c8ef 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -85,21 +85,19 @@ template int get_num_threads_per_block(std::shared_ptr exec, const int num_rows) { - int nwarps = num_rows / 4; - if (nwarps < 2) { - nwarps = 2; - } - const int min_block_size = 2 * config::warp_size; + int num_warps = std::max(num_rows / 4, 2); + constexpr int warp_sz = static_cast(config::warp_size); + const int min_block_size = 2 * warp_sz; const int device_max_threads = - ((std::max(num_rows, min_block_size)) / config::warp_size) * - config::warp_size; + ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz; const int num_regs_used_per_thread = 64; int max_regs_blk = 0; hipDeviceGetAttribute(&max_regs_blk, hipDeviceAttributeMaxRegistersPerBlock, exec->get_device_id()); const int max_threads_regs = (max_regs_blk / num_regs_used_per_thread); - const int max_threads = std::min(max_threads_regs, device_max_threads); - return std::min(nwarps * static_cast(config::warp_size), max_threads); + int max_threads = std::min(max_threads_regs, device_max_threads); + max_threads = max_threads <= 1024 ? max_threads : 1024; + return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size); } @@ -108,12 +106,12 @@ using settings = gko::kernels::batch_bicgstab::settings; template -class KernelCaller { +class kernel_caller { public: using value_type = HipValueType; - KernelCaller(std::shared_ptr exec, - const settings> settings) + kernel_caller(std::shared_ptr exec, + const settings> settings) : exec_{exec}, settings_{settings} {} @@ -147,12 +145,11 @@ class KernelCaller { const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; const int padded_num_rows = - ((mat.num_rows + align_multiple - 1) / align_multiple) * - align_multiple; + ceildiv(mat.num_rows, align_multiple) * align_multiple; const int shmem_per_blk = exec_->get_max_shared_memory_per_block(); const int block_size = get_num_threads_per_block(exec_, mat.num_rows); - assert(block_size >= 2 * config::warp_size); + GKO_ASSERT(block_size >= 2 * config::warp_size); const size_t prec_size = PrecType::dynamic_work_size(padded_num_rows, @@ -175,62 +172,64 @@ class KernelCaller { // Template parameters launch_apply_kernel( + if (sconf.prec_shared) { + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); - else { + } else { switch (sconf.n_shared) { case 0: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 1: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 2: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 3: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 4: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); 
break; case 5: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 6: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 7: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 8: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; case 9: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); break; + default: + GKO_NOT_IMPLEMENTED; } } @@ -254,7 +253,7 @@ void apply(std::shared_ptr exec, { using hip_value_type = hip_type; auto dispatcher = batch::solver::create_dispatcher( - KernelCaller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); } diff --git a/reference/matrix/batch_struct.hpp b/reference/matrix/batch_struct.hpp index bb7680d1493..94beff5c2c2 100644 --- a/reference/matrix/batch_struct.hpp +++ b/reference/matrix/batch_struct.hpp @@ -95,7 +95,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch +inline batch::matrix::ell::uniform_batch get_batch_struct(const batch::matrix::Ell* const op) { return {op->get_const_values(), diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp index c47c80e64dc..311fb40e5ef 100644 --- a/reference/test/solver/batch_bicgstab_kernels.cpp +++ b/reference/test/solver/batch_bicgstab_kernels.cpp @@ -108,8 +108,8 @@ TYPED_TEST(BatchBicgstab, SolvesStencilSystem) this->linear_system); for (size_t i = 0; i < this->num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - this->linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + this->linear_system.host_rhs_norm->get_const_values()[i], this->solver_settings.residual_tol); } GKO_ASSERT_BATCH_MTX_NEAR(res.x, this->linear_system.exact_sol, @@ -130,9 +130,10 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsResidual) auto iter_array = res.log_data->iter_counts.get_const_data(); auto res_log_array = res.log_data->res_norms.get_const_data(); for (size_t i = 0; i < this->num_batch_items; i++) { - ASSERT_LE(res_log_array[i] / this->linear_system.rhs_norm->at(i, 0, 0), - this->solver_settings.residual_tol); - ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i], + ASSERT_LE( + res_log_array[i] / this->linear_system.host_rhs_norm->at(i, 0, 0), + this->solver_settings.residual_tol); + ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i], 10 * this->eps); } } @@ -186,8 +187,8 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i], tol); } } @@ -228,8 +229,8 @@ TYPED_TEST(BatchBicgstab, ApplyLogsResAndIters) auto res_norm = logger->get_residual_norm(); GKO_ASSERT_BATCH_MTX_NEAR(res.x, 
linear_system.exact_sol, tol * 50); for (size_t i = 0; i < num_batch_items; i++) { - auto rel_res_norm = res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i]; + auto rel_res_norm = res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; ASSERT_LE(iter_counts.get_const_data()[i], max_iters); EXPECT_LE(res_norm.get_const_data()[i], tol * 50); ASSERT_LE(rel_res_norm, tol * 50); @@ -266,8 +267,8 @@ TYPED_TEST(BatchBicgstab, CanSolveEllSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i], tol * 10); } } @@ -302,6 +303,6 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i], tol * 50); + ASSERT_LE(res.host_res_norm->get_const_values()[i], tol * 50); } } diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp index 124dd27640c..ea5e7ec782f 100644 --- a/test/solver/batch_bicgstab_kernels.cpp +++ b/test/solver/batch_bicgstab_kernels.cpp @@ -117,8 +117,8 @@ TEST_F(BatchBicgstab, SolvesStencilSystem) solver_settings, linear_system); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i], solver_settings.residual_tol); } GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol); @@ -141,9 +141,9 @@ TEST_F(BatchBicgstab, StencilSystemLoggerLogsResidual) auto res_log_array = res.log_data->res_norms.get_const_data(); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res_log_array[i] / linear_system.rhs_norm->at(i, 0, 0), + ASSERT_LE(res_log_array[i] / linear_system.host_rhs_norm->at(i, 0, 0), solver_settings.residual_tol); - ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i], + ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i], 10 * tol); } } @@ -185,8 +185,8 @@ TEST_F(BatchBicgstab, CanSolve3ptStencilSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); for (size_t i = 0; i < num_batch_items; i++) { - auto comp_res_norm = res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i]; + auto comp_res_norm = res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; ASSERT_LE(comp_res_norm, tol); } } @@ -215,11 +215,11 @@ TEST_F(BatchBicgstab, CanSolveLargeBatchSizeHpdSystem) &logger->get_residual_norm()); GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); for (size_t i = 0; i < num_batch_items; i++) { - auto comp_res_norm = res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i]; + auto comp_res_norm = res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; ASSERT_LE(iter_counts->get_const_data()[i], max_iters); EXPECT_LE(res_norm->get_const_data()[i] / - linear_system.rhs_norm->get_const_values()[i], + linear_system.host_rhs_norm->get_const_values()[i], tol); EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); ASSERT_LE(comp_res_norm, tol); 
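
Note: the HPD-system hunks above and below this point all follow the same host-side
verification pattern: the solver's per-item convergence log is copied to the host
with make_temporary_clone, and each batch entry is checked against the relative
stopping criterion ||r|| <= ||b|| * tau. A condensed sketch of that pattern, using
the fixture names from this test file (the slack factors such as tol * 50 are
choices made by these tests, not something the solver mandates):

    // bring the device-side convergence log over to the host executor
    auto iter_counts = gko::make_temporary_clone(
        exec->get_master(), &logger->get_num_iterations());
    auto res_norm = gko::make_temporary_clone(
        exec->get_master(), &logger->get_residual_norm());
    for (size_t i = 0; i < num_batch_items; i++) {
        // relative residual of batch item i: ||r_i|| / ||b_i||
        auto comp_res_norm = res.host_res_norm->get_const_values()[i] /
                             linear_system.host_rhs_norm->get_const_values()[i];
        ASSERT_LE(iter_counts->get_const_data()[i], max_iters);
        ASSERT_LE(comp_res_norm, tol);
    }
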
@@ -250,11 +250,11 @@ TEST_F(BatchBicgstab, CanSolveLargeMatrixSizeHpdSystem) &logger->get_residual_norm()); GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); for (size_t i = 0; i < num_batch_items; i++) { - auto comp_res_norm = res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i]; + auto comp_res_norm = res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; ASSERT_LE(iter_counts->get_const_data()[i], max_iters); EXPECT_LE(res_norm->get_const_data()[i] / - linear_system.rhs_norm->get_const_values()[i], + linear_system.host_rhs_norm->get_const_values()[i], tol); EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); ASSERT_LE(comp_res_norm, tol); From 84be7dd4acc2b4aee33b09a60be6d094faaa929a Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Wed, 1 Nov 2023 10:07:23 +0100 Subject: [PATCH 11/28] Use synchronize for error handling --- core/base/batch_utilities.hpp | 3 +- cuda/base/exception.cuh | 59 ----------------------- cuda/solver/batch_bicgstab_kernels.cu | 2 +- hip/base/exception.hip.hpp | 58 ---------------------- hip/solver/batch_bicgstab_kernels.hip.cpp | 2 +- 5 files changed, 4 insertions(+), 120 deletions(-) delete mode 100644 cuda/base/exception.cuh delete mode 100644 hip/base/exception.hip.hpp diff --git a/core/base/batch_utilities.hpp b/core/base/batch_utilities.hpp index f05a80322aa..cc92d294173 100644 --- a/core/base/batch_utilities.hpp +++ b/core/base/batch_utilities.hpp @@ -201,8 +201,9 @@ std::unique_ptr read( std::forward(create_args)...); for (size_type b = 0; b < num_batch_items; ++b) { - if (data.at(b).size != data.at(0).size) + if (data.at(b).size != data.at(0).size) { GKO_INVALID_STATE("Incorrect data passed in"); + } tmp->create_view_for_item(b)->read(data[b]); } diff --git a/cuda/base/exception.cuh b/cuda/base/exception.cuh deleted file mode 100644 index ccf74ebdb7b..00000000000 --- a/cuda/base/exception.cuh +++ /dev/null @@ -1,59 +0,0 @@ -/************************************************************* -Copyright (c) 2017-2023, the Ginkgo authors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - -1. Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS -IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED -TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*************************************************************/ - -#ifndef GKO_CUDA_BASE_EXCEPTION_CUH_ -#define GKO_CUDA_BASE_EXCEPTION_CUH_ - - -#include - - -namespace gko { - - -#define GKO_CUDA_LAST_IF_ERROR_THROW \ - { \ - cudaError_t err = cudaGetLastError(); \ - if (err != cudaSuccess) { \ - printf(" Kernel error: %s\n", cudaGetErrorString(err)); \ - throw gko::CudaError(__FILE__, __LINE__, __func__, err); \ - } \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") - - -} // namespace gko - - -#endif // GKO_CUDA_BASE_EXCEPTION_CUH_ diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 1c26a1754ba..dd8c0487c23 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -265,7 +265,7 @@ public: } } - GKO_CUDA_LAST_IF_ERROR_THROW; + exec_->synchronize(); } private: diff --git a/hip/base/exception.hip.hpp b/hip/base/exception.hip.hpp deleted file mode 100644 index 366f95bffbb..00000000000 --- a/hip/base/exception.hip.hpp +++ /dev/null @@ -1,58 +0,0 @@ -/************************************************************* -Copyright (c) 2017-2023, the Ginkgo authors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - -1. Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS -IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED -TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-*************************************************************/ - -#ifndef GKO_HIP_BASE_EXCEPTION_HIP_HPP_ -#define GKO_HIP_BASE_EXCEPTION_HIP_HPP_ - - -#include - - -namespace gko { - - -#define GKO_HIP_LAST_IF_ERROR_THROW \ - { \ - hipError_t err = hipGetLastError(); \ - if (err != hipSuccess) { \ - printf(" Hip kernel error: %s\n", hipGetErrorString(err)); \ - throw gko::HipError(__FILE__, __LINE__, __func__, err); \ - } \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") - - -} // namespace gko - -#endif // GKO_HIP_BASE_EXCEPTION_HIP_HPP_ diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index a7c3667c8ef..a56440a7310 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -233,7 +233,7 @@ class kernel_caller { } } - GKO_HIP_LAST_IF_ERROR_THROW; + exec_->synchronize(); } private: From c6e9543012079c35810611a1c9c0f40f407eb3ec Mon Sep 17 00:00:00 2001 From: ginkgo-bot Date: Wed, 1 Nov 2023 10:59:38 +0000 Subject: [PATCH 12/28] Format files Co-authored-by: Pratik Nayak --- .../base/batch_multi_vector_kernels.hpp.inc | 8 +- cuda/solver/batch_bicgstab_kernels.cu | 1 - dpcpp/base/batch_multi_vector_kernels.dp.cpp | 125 +++++++++--------- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 7 +- hip/solver/batch_bicgstab_kernels.hip.cpp | 1 - 5 files changed, 73 insertions(+), 69 deletions(-) diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc index 1e0cb3bbcff..cb157d80fd5 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc @@ -105,10 +105,10 @@ __global__ __launch_bounds__( template __device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup, - const int num_rows, - const ValueType* x, - const ValueType* y, - ValueType& result) + const int num_rows, + const ValueType* x, + const ValueType* y, + ValueType& result) { ValueType val = zero(); diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index dd8c0487c23..73a74e52172 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -46,7 +46,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "core/solver/batch_dispatch.hpp" #include "cuda/base/batch_struct.hpp" #include "cuda/base/config.hpp" -#include "cuda/base/exception.cuh" #include "cuda/base/kernel_config.cuh" #include "cuda/base/thrust.cuh" #include "cuda/base/types.hpp" diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index c9809696889..51665d26ff9 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -210,37 +210,41 @@ void compute_dot(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_conj_dot_sg(x_b.num_rows, x_b.values, - y_b.values, res_b.values[0], - item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto y_b = + batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_conj_dot_sg( + x_b.num_rows, x_b.values, y_b.values, + res_b.values[0], item_ct1); + }); }); } else { // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return val; }); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto y_b = + batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return val; }); + }); }); } } @@ -274,18 +278,19 @@ void compute_conj_dot(std::shared_ptr exec, exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return conj(val); }); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + 
[[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return conj(val); }); + }); }); } @@ -317,31 +322,33 @@ void compute_norm2(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, - res_b.values[0], item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, + res_b.values[0], item_ct1); + }); }); } else { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); - }); + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_norm2_kernel(x_b, res_b, item_ct1); + }); }); } } diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 52f794cfc0e..fa47352457e 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -117,10 +117,9 @@ class KernelCaller { slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - subgroup_size)]] [ - [intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [[intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index a56440a7310..e042b6137d8 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -47,7 +47,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "core/solver/batch_dispatch.hpp" #include "hip/base/batch_struct.hip.hpp" #include "hip/base/config.hip.hpp" -#include "hip/base/exception.hip.hpp" #include "hip/base/math.hip.hpp" #include "hip/base/thrust.hip.hpp" #include "hip/base/types.hip.hpp" From 1054b7b00207a57abae533c533c97af017fcfdac Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Wed, 1 Nov 2023 16:09:35 +0100 Subject: [PATCH 13/28] Add scoped cuda shmem config --- .../{kernel_config.cuh => kernel_config.hpp} | 62 ++++++++++++++----- cuda/solver/batch_bicgstab_kernels.cu | 6 +- 2 files changed, 50 insertions(+), 18 deletions(-) rename cuda/base/{kernel_config.cuh => kernel_config.hpp} (53%) diff --git a/cuda/base/kernel_config.cuh b/cuda/base/kernel_config.hpp similarity index 53% rename from cuda/base/kernel_config.cuh rename to cuda/base/kernel_config.hpp index 6280753bcda..1fbc0d6e4d8 100644 --- a/cuda/base/kernel_config.cuh +++ b/cuda/base/kernel_config.hpp @@ -30,36 +30,66 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *************************************************************/ -#ifndef GKO_CUDA_BASE_KERNEL_CONFIG_CUH_ -#define GKO_CUDA_BASE_KERNEL_CONFIG_CUH_ +#ifndef GKO_CUDA_BASE_KERNEL_CONFIG_HPP_ +#define GKO_CUDA_BASE_KERNEL_CONFIG_HPP_ -#include "cuda/base/math.hpp" +#include namespace gko { namespace kernels { namespace cuda { +namespace detail { -/** - * Set shared memory bank configuration. - * - * \tparam ValueType The scalar type used for computations. - */ template -inline void configure_shared_memory_banks() -{ - if (sizeof(ValueType) == 4) { - cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeFourByte); - } else if (sizeof(ValueType) % 8 == 0) { - cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte); +class shared_memory_config_guard { +public: + using value_type = ValueType; + shared_memory_config_guard() : original_config_{} + { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceGetSharedMemConfig(&original_config_)); + + if (sizeof(value_type) == 4) { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeFourByte)); + } else if (sizeof(value_type) % 8 == 0) { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte)); + } else { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeDefault)); + } + } + + + ~shared_memory_config_guard() + { + auto error_code = cudaDeviceSetSharedMemConfig(original_config_); + if (error_code != cudaSuccess) { +#if GKO_VERBOSE_LEVEL >= 1 + std::cerr << "Unrecoverable CUDA error while resetting the " + "shared memory config to " + << original_config_ << " in " << __func__ << ": " + << cudaGetErrorName(error_code) << ": " + << cudaGetErrorString(error_code) << std::endl + << "Exiting program" << std::endl; +#endif // GKO_VERBOSE_LEVEL >= 1 + std::exit(error_code); + } } -} + +private: + cudaSharedMemConfig original_config_; +}; +} // namespace detail } // namespace cuda } // namespace kernels } // namespace gko -#endif // GKO_CUDA_BASE_KERNEL_CONFIG_CUH_ + +#endif // GKO_CUDA_BASE_KERNEL_CONFIG_HPP_ diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 73a74e52172..16df7e7e55e 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -46,7 +46,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "core/solver/batch_dispatch.hpp" #include "cuda/base/batch_struct.hpp" #include "cuda/base/config.hpp" -#include "cuda/base/kernel_config.cuh" +#include "cuda/base/kernel_config.hpp" #include "cuda/base/thrust.cuh" #include "cuda/base/types.hpp" #include "cuda/components/cooperative_groups.cuh" @@ -172,7 +172,9 @@ public: constexpr int align_multiple = 8; const int padded_num_rows = ceildiv(mat.num_rows, align_multiple) * align_multiple; - gko::kernels::cuda::configure_shared_memory_banks(); + auto shem_guard = + gko::kernels::cuda::detail::shared_memory_config_guard< + value_type>(); const int shmem_per_blk = get_max_dynamic_shared_memory(exec_); From cc22557b9bb927425ece230dcf472dd68a95c13c Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Wed, 1 Nov 2023 17:45:10 +0100 Subject: [PATCH 14/28] move max_shmem query to internal --- cuda/base/executor.cpp | 3 --- hip/base/executor.hip.cpp | 3 --- hip/solver/batch_bicgstab_kernels.hip.cpp | 10 +++++++--- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/cuda/base/executor.cpp b/cuda/base/executor.cpp index 01880127641..f296fb9da86 100644 --- a/cuda/base/executor.cpp +++ b/cuda/base/executor.cpp @@ -258,9 +258,6 @@ void CudaExecutor::set_gpu_property() kernels::cuda::config::warp_size; this->get_exec_info().max_subgroup_size = kernels::cuda::config::warp_size; - GKO_ASSERT_NO_CUDA_ERRORS(cudaDeviceGetAttribute( - &this->get_exec_info().max_shared_memory_per_workgroup, - cudaDevAttrMaxSharedMemoryPerBlock, this->get_device_id())); } } diff --git a/hip/base/executor.hip.cpp b/hip/base/executor.hip.cpp index 489e9b28ff9..8d175c0e424 100644 --- a/hip/base/executor.hip.cpp +++ b/hip/base/executor.hip.cpp @@ -262,9 +262,6 @@ void HipExecutor::set_gpu_property() #endif // GINKGO_HIP_PLATFORM_NVCC this->get_exec_info().max_subgroup_size = kernels::hip::config::warp_size; - GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute( - &this->get_exec_info().max_shared_memory_per_workgroup, - hipDeviceAttributeMaxSharedMemoryPerBlock, this->get_device_id())); } } diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index e042b6137d8..4a04317ca9d 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -91,8 +91,9 @@ int get_num_threads_per_block(std::shared_ptr exec, ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz; const int num_regs_used_per_thread = 64; int max_regs_blk = 0; - hipDeviceGetAttribute(&max_regs_blk, hipDeviceAttributeMaxRegistersPerBlock, - exec->get_device_id()); + GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute( + &max_regs_blk, hipDeviceAttributeMaxRegistersPerBlock, + exec->get_device_id())); const int max_threads_regs = (max_regs_blk / num_regs_used_per_thread); int max_threads = std::min(max_threads_regs, device_max_threads); max_threads = max_threads <= 1024 ? 
max_threads : 1024; @@ -145,7 +146,10 @@ class kernel_caller { constexpr int align_multiple = 8; const int padded_num_rows = ceildiv(mat.num_rows, align_multiple) * align_multiple; - const int shmem_per_blk = exec_->get_max_shared_memory_per_block(); + int shmem_per_blk = 0; + GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute( + &shmem_per_blk, hipDeviceAttributeMaxSharedMemoryPerBlock, + exec_->get_device_id())); const int block_size = get_num_threads_per_block(exec_, mat.num_rows); GKO_ASSERT(block_size >= 2 * config::warp_size); From 501c4e71fd4dd302e2cc970e928fd4288002a04d Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Thu, 2 Nov 2023 16:08:21 +0100 Subject: [PATCH 15/28] Update size_type in tests --- test/matrix/batch_dense_kernels.cpp | 6 +++--- test/matrix/batch_ell_kernels.cpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/matrix/batch_dense_kernels.cpp b/test/matrix/batch_dense_kernels.cpp index 1f3967b0eb8..fa75a8f61e4 100644 --- a/test/matrix/batch_dense_kernels.cpp +++ b/test/matrix/batch_dense_kernels.cpp @@ -71,9 +71,9 @@ class Dense : public CommonTestFixture { std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } - void set_up_apply_data(int num_rows, gko::size_type num_vecs = 1) + void set_up_apply_data(gko::size_type num_rows, gko::size_type num_vecs = 1) { - const int num_cols = 32; + const gko::size_type num_cols = 32; mat = gen_mtx(batch_size, num_rows, num_cols); y = gen_mtx(batch_size, num_cols, num_vecs); alpha = gen_mtx(batch_size, 1, 1); @@ -91,7 +91,7 @@ class Dense : public CommonTestFixture { std::default_random_engine rand_engine; - const size_t batch_size = 11; + const gko::size_type batch_size = 11; std::unique_ptr mat; std::unique_ptr y; std::unique_ptr alpha; diff --git a/test/matrix/batch_ell_kernels.cpp b/test/matrix/batch_ell_kernels.cpp index 572f47ba47d..7a4c6558c5d 100644 --- a/test/matrix/batch_ell_kernels.cpp +++ b/test/matrix/batch_ell_kernels.cpp @@ -87,8 +87,8 @@ class Ell : public CommonTestFixture { void set_up_apply_data(gko::size_type num_vecs = 1, int num_elems_per_row = 5) { - const int num_rows = 252; - const int num_cols = 32; + const gko::size_type num_rows = 252; + const gko::size_type num_cols = 32; GKO_ASSERT(num_elems_per_row <= num_cols); mat = gen_mtx(batch_size, num_rows, num_cols, num_elems_per_row); y = gen_mvec(batch_size, num_cols, num_vecs); @@ -107,7 +107,7 @@ class Ell : public CommonTestFixture { std::ranlux48 rand_engine; - const size_t batch_size = 11; + const gko::size_type batch_size = 11; std::unique_ptr mat; std::unique_ptr y; std::unique_ptr alpha; From 7b0ebfd75c0e00be874bd794da4ee9d73574f38c Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Thu, 2 Nov 2023 16:24:39 +0100 Subject: [PATCH 16/28] Update contributors.txt Co-authored-by: Phuong Nguyen --- contributors.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/contributors.txt b/contributors.txt index 1f1259bc082..aec120d93dd 100644 --- a/contributors.txt +++ b/contributors.txt @@ -20,6 +20,7 @@ Kashi Aditya Karlsruhe Institute of Technology Koch Marcel Karlsruhe Institute of Technology Maier Matthias Texas A&M University Nayak Pratik Karlsruhe Institute of Technology +Nguyen Phuong University of Tennessee, Knoxville Olenik Gregor HPSim Ribizel Tobias Karlsruhe Institute of Technology Riemer Lukas Karlsruhe Institute of Technology From f1babfd354121621b7560ab4f885b8780341f04e Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Thu, 2 Nov 2023 16:58:15 +0100 Subject: [PATCH 17/28] review updates Co-authored-by: Yu-Hsiang 
Tsai --- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 14 ++++++++------ dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 2 ++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index fa47352457e..978ab94d9c4 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -77,7 +77,8 @@ template using settings = gko::kernels::batch_bicgstab::settings; -__dpct_inline__ int get_group_size(int value, int subgroup_size = 32) +__dpct_inline__ int get_group_size(int value, + int subgroup_size = config::warp_size) { int num_sg = ceildiv(value, subgroup_size); return num_sg * subgroup_size; @@ -117,9 +118,10 @@ class KernelCaller { slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - subgroup_size)]] [[intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [ + [intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); @@ -163,7 +165,7 @@ class KernelCaller { // alpha, omega, temp and for reduce_over_group // If the value available is negative, then set it to 0 const int static_var_mem = - (group_size + 5) * sizeof(ValueType) - 2 * sizeof(real_type); + (group_size + 5) * sizeof(ValueType) + 2 * sizeof(real_type); int shmem_per_blk = std::max( static_cast( device.get_info()) - @@ -191,7 +193,7 @@ class KernelCaller { // launch_apply_kernel if (num_rows <= 32 && n_shared_total == 10) { - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); } else if (num_rows <= 256 && n_shared_total == 10) { diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 0b6f4511f02..e71eb060afa 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -259,6 +259,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, tile_real_t& reals = *sycl::ext::oneapi::group_local_memory_for_overwrite( group); + // ValueType values[5]; + // real_type reals[2]; rho_old_sh = &values[0]; rho_new_sh = &values[1]; alpha_sh = &values[2]; From 221bba9d0b6b70ca843da4278d8c8410bd3db4ad Mon Sep 17 00:00:00 2001 From: ginkgo-bot Date: Thu, 2 Nov 2023 16:02:16 +0000 Subject: [PATCH 18/28] Format files Co-authored-by: Pratik Nayak --- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 978ab94d9c4..839cb9e0976 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -118,10 +118,9 @@ class KernelCaller { slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - subgroup_size)]] [ - [intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [[intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); From 
aa026c1b2bce8e7a579965a5998ee53a4ac0ea42 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Thu, 2 Nov 2023 17:54:26 +0100 Subject: [PATCH 19/28] dpcpp group size and doc fixes --- common/cuda_hip/log/batch_logger.hpp.inc | 2 +- core/solver/batch_bicgstab_kernels.hpp | 2 +- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 7 ++++++- hip/solver/batch_bicgstab_kernels.hip.cpp | 3 +++ include/ginkgo/core/stop/batch_stop_enum.hpp | 2 +- test/solver/batch_bicgstab_kernels.cpp | 8 ++++---- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/common/cuda_hip/log/batch_logger.hpp.inc b/common/cuda_hip/log/batch_logger.hpp.inc index 7a4d59b67e9..e8cf77960ef 100644 --- a/common/cuda_hip/log/batch_logger.hpp.inc +++ b/common/cuda_hip/log/batch_logger.hpp.inc @@ -36,7 +36,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. template class SimpleFinalLogger final { public: - using real_type = remove_complex; + using real_type = RealType; SimpleFinalLogger(real_type* const batch_residuals, int* const batch_iters) : final_residuals_{batch_residuals}, final_iters_{batch_iters} diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp index 6f5de2e770c..32291562afd 100644 --- a/core/solver/batch_bicgstab_kernels.hpp +++ b/core/solver/batch_bicgstab_kernels.hpp @@ -192,7 +192,7 @@ storage_config compute_shared_storage(const int available_shared_mem, sconf.prec_shared = true; rem_shared -= prec_storage; } - // Set the global storage config and align to 32 bytes. + // Set the global storage config and align to align_bytes bytes. set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; } diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 839cb9e0976..9da926c7c58 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -154,11 +154,16 @@ class KernelCaller { GKO_ASSERT(num_rhs == 1); auto device = exec_->get_queue()->get_device(); - auto group_size = + auto max_group_size = + device.get_info(); + int group_size = device.get_info(); if (group_size > num_rows) { group_size = get_group_size(num_rows); }; + group_size = std::min( + std::max(group_size, static_cast(2 * config::warp_size)), + static_cast(max_group_size)); // reserve 5 for intermediate rho-s, norms, // alpha, omega, temp and for reduce_over_group diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 4a04317ca9d..fbd6543574f 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -89,6 +89,9 @@ int get_num_threads_per_block(std::shared_ptr exec, const int min_block_size = 2 * warp_sz; const int device_max_threads = ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz; + // This value has been taken from RocM docs. This is the number of registers + // that maximizes the occupancy on an AMD GPU (MI200). HIP does not have an + // API to query the number of registers a function uses. 
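// Illustrative arithmetic for the bound computed below (assumed numbers for a
// sketch, not values queried from a device): if the attribute
// hipDeviceAttributeMaxRegistersPerBlock reports 65536 registers, then
// max_threads_regs = 65536 / 64 = 1024 threads, and the subsequent clamp
// keeps the result within the 1024-thread block-size limit.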
const int num_regs_used_per_thread = 64; int max_regs_blk = 0; GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute( diff --git a/include/ginkgo/core/stop/batch_stop_enum.hpp b/include/ginkgo/core/stop/batch_stop_enum.hpp index 1694dd164d9..3c463b8730c 100644 --- a/include/ginkgo/core/stop/batch_stop_enum.hpp +++ b/include/ginkgo/core/stop/batch_stop_enum.hpp @@ -48,7 +48,7 @@ namespace stop { * * With the `relative` tolerance type, the solver * convergence criteria checks against the relative residual norm - * ($||r|| \leq ||b|| \times \tau$, where $||b||$$ is the L2 norm of the rhs). + * ($||r|| \leq ||b|| \times \tau$, where $||b||$ is the L2 norm of the rhs). * * @note the computed residual norm, $||r||$ may be implicit or explicit * depending on the solver algorithm. diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp index ea5e7ec782f..f99e7a469d0 100644 --- a/test/solver/batch_bicgstab_kernels.cpp +++ b/test/solver/batch_bicgstab_kernels.cpp @@ -198,7 +198,7 @@ TEST_F(BatchBicgstab, CanSolveLargeBatchSizeHpdSystem) const int num_rows = 102; const int num_rhs = 1; const real_type tol = 1e-5; - const int max_iters = num_rows; + const int max_iters = num_rows * 2; std::shared_ptr logger = Logger::create(); auto mat = gko::share(gko::test::generate_diag_dominant_batch_matrix( exec, num_batch_items, num_rows, true)); @@ -213,7 +213,7 @@ TEST_F(BatchBicgstab, CanSolveLargeBatchSizeHpdSystem) &logger->get_num_iterations()); auto res_norm = gko::make_temporary_clone(exec->get_master(), &logger->get_residual_norm()); - GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 500); for (size_t i = 0; i < num_batch_items; i++) { auto comp_res_norm = res.host_res_norm->get_const_values()[i] / linear_system.host_rhs_norm->get_const_values()[i]; @@ -233,7 +233,7 @@ TEST_F(BatchBicgstab, CanSolveLargeMatrixSizeHpdSystem) const int num_rows = 1025; const int num_rhs = 1; const real_type tol = 1e-5; - const int max_iters = num_rows; + const int max_iters = num_rows * 2; std::shared_ptr logger = Logger::create(); auto mat = gko::share(gko::test::generate_diag_dominant_batch_matrix( exec, num_batch_items, num_rows, true)); @@ -248,7 +248,7 @@ TEST_F(BatchBicgstab, CanSolveLargeMatrixSizeHpdSystem) &logger->get_num_iterations()); auto res_norm = gko::make_temporary_clone(exec->get_master(), &logger->get_residual_norm()); - GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 500); for (size_t i = 0; i < num_batch_items; i++) { auto comp_res_norm = res.host_res_norm->get_const_values()[i] / linear_system.host_rhs_norm->get_const_values()[i]; From 79e5cad63d24cd130e38bf6b21b9ed09902a1675 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Fri, 3 Nov 2023 12:10:35 +0100 Subject: [PATCH 20/28] use global_and_local barrier --- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 40 +++++++++++---------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index e71eb060afa..3efb93e664b 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -60,13 +60,13 @@ __dpct_inline__ void initialize( x_shared_entry[iz] = x_global_entry[iz]; r_shared_entry[iz] = b_global_entry[iz]; } - item_ct1.barrier(sycl::access::fence_space::local_space); + 
item_ct1.barrier(sycl::access::fence_space::global_and_local); // r = b - A*x advanced_apply_kernel(static_cast(-1.0), mat_global_entry, x_shared_entry, static_cast(1.0), r_shared_entry, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); if constexpr (sg_kernel_all) { if (sg_id == 0) { @@ -80,7 +80,7 @@ __dpct_inline__ void initialize( single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); for (int iz = tid; iz < num_rows; iz += group_size) { @@ -125,10 +125,11 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, v_shared_entry, alpha, item_ct1); } + item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { alpha = rho_new / alpha; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry, v_shared_entry, alpha, item_ct1); @@ -172,11 +173,11 @@ __dpct_inline__ void compute_omega(const int num_rows, single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry, temp, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { omega /= temp; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry, omega, item_ct1); @@ -349,7 +350,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, x_global_entry, rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // stopping criterion object StopType stop(tol, norms_rhs_sh); @@ -367,7 +368,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, rho_new_sh[0], item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], item_ct1); @@ -377,24 +378,24 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // p = r + beta*(p - omega * v) update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], omega_sh[0], r_sh, v_sh, p_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // p_hat = precond * p prec_shared.apply(num_rows, p_sh, p_hat_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // v = A * p_hat simple_apply_kernel(mat_global_entry, p_hat_sh, v_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // alpha = rho_new / < r_hat , v> compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0], item_ct1); - 
item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // s = r - alpha*v update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // an estimate of residual norms if constexpr (sg_kernel_all) { @@ -402,7 +403,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1); } @@ -415,22 +416,22 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // s_hat = precond * s prec_shared.apply(num_rows, s_sh, s_hat_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // t = A * s_hat simple_apply_kernel(mat_global_entry, s_hat_sh, t_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // omega = / compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // x = x + alpha*p_hat + omega *s_hat // r = s - omega * t update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], omega_sh[0], s_sh, t_sh, x_sh, r_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); if constexpr (sg_kernel_all) { if (sg_id == 0) @@ -439,7 +440,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, if (tid == group_size - 1) { rho_old_sh[0] = rho_new_sh[0]; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1); rho_old_sh[0] = rho_new_sh[0]; @@ -450,4 +451,5 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // copy x back to global memory copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } From 693d308bf9b55eccb76489b62546450d729e302d Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Fri, 3 Nov 2023 19:56:38 +0100 Subject: [PATCH 21/28] Fix Intel2020 apply call issue --- core/matrix/batch_ell.cpp | 13 +++++++++++-- include/ginkgo/core/solver/batch_solver_base.hpp | 14 ++++++++++++-- reference/test/solver/batch_bicgstab_kernels.cpp | 2 +- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/core/matrix/batch_ell.cpp b/core/matrix/batch_ell.cpp index 19b2dcae5c3..88863a05dd4 100644 --- a/core/matrix/batch_ell.cpp +++ b/core/matrix/batch_ell.cpp @@ -134,7 +134,10 @@ Ell* Ell::apply( ptr_param> b, ptr_param> x) { - static_cast(this)->apply(b, x); + this->validate_application_parameters(b.get(), x.get()); + auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return this; } @@ -159,7 +162,13 @@ Ell* Ell::apply( ptr_param> beta, ptr_param> x) { - static_cast(this)->apply(alpha, b, beta, x); + this->validate_application_parameters(alpha.get(), b.get(), beta.get(), + x.get()); + 
auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, beta).get(), + make_temporary_clone(exec, x).get()); return this; } diff --git a/include/ginkgo/core/solver/batch_solver_base.hpp b/include/ginkgo/core/solver/batch_solver_base.hpp index 3141812e259..8cc5c67837a 100644 --- a/include/ginkgo/core/solver/batch_solver_base.hpp +++ b/include/ginkgo/core/solver/batch_solver_base.hpp @@ -277,6 +277,7 @@ class EnableBatchSolver public EnableBatchLinOp { public: using real_type = remove_complex; + const ConcreteSolver* apply(ptr_param> b, ptr_param> x) const { @@ -305,7 +306,10 @@ class EnableBatchSolver ConcreteSolver* apply(ptr_param> b, ptr_param> x) { - static_cast(this)->apply(b, x); + this->validate_application_parameters(b.get(), x.get()); + auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return self(); } @@ -314,7 +318,13 @@ class EnableBatchSolver ptr_param> beta, ptr_param> x) { - static_cast(this)->apply(alpha, b, beta, x); + this->validate_application_parameters(alpha.get(), b.get(), beta.get(), + x.get()); + auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, beta).get(), + make_temporary_clone(exec, x).get()); return self(); } diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp index 311fb40e5ef..211318e8a8f 100644 --- a/reference/test/solver/batch_bicgstab_kernels.cpp +++ b/reference/test/solver/batch_bicgstab_kernels.cpp @@ -87,7 +87,7 @@ class BatchBicgstab : public ::testing::Test { std::shared_ptr exec; const real_type eps = 1e-3; const gko::size_type num_batch_items = 2; - const int num_rows = 3; + const int num_rows = 15; const int num_rhs = 1; const Settings solver_settings{100, eps, gko::batch::stop::tolerance_type::relative}; From 705339ec658df8d53652a30160f90e5877dd0631 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Fri, 3 Nov 2023 20:42:21 +0100 Subject: [PATCH 22/28] Fix diag_dominance and tol issue --- core/test/utils/batch_helpers.hpp | 15 ++++++-- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 42 ++++++++++----------- test/solver/batch_bicgstab_kernels.cpp | 4 +- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp index 43da4cd9d54..eee31050505 100644 --- a/core/test/utils/batch_helpers.hpp +++ b/core/test/utils/batch_helpers.hpp @@ -166,7 +166,7 @@ std::unique_ptr generate_diag_dominant_batch_matrix( static_cast(num_cols)}, {}}; auto engine = std::default_random_engine(42); - auto rand_diag_dist = std::normal_distribution(4.0, 12.0); + auto rand_diag_dist = std::normal_distribution(8.0, 1.0); for (int row = 0; row < num_rows; ++row) { std::uniform_int_distribution rand_nnz_dist{1, row + 1}; const auto k = rand_nnz_dist(engine); @@ -175,8 +175,8 @@ std::unique_ptr generate_diag_dominant_batch_matrix( } data.nonzeros.emplace_back( row, row, - static_cast( - detail::get_rand_value(rand_diag_dist, engine))); + std::abs(static_cast( + detail::get_rand_value(rand_diag_dist, engine)))); if (row < num_rows - 1) { data.nonzeros.emplace_back(row, k, value_type{-1.0}); data.nonzeros.emplace_back(row, row + 1, value_type{-1.0}); @@ -208,8 +208,15 @@ std::unique_ptr generate_diag_dominant_batch_matrix( auto rand_data = 
fill_random_matrix_data( num_rows, num_cols, row_idxs, col_idxs, rand_val_dist, engine); gko::utils::make_diag_dominant(rand_data); - batch_data.emplace_back(rand_data); GKO_ASSERT(rand_data.size == batch_data.at(0).size); + GKO_ASSERT(rand_data.nonzeros.size() == data.nonzeros.size()); + // Copy over the diagonal values + for (int i = 0; i < data.nonzeros.size(); ++i) { + if (data.nonzeros[i].row == data.nonzeros[i].column) { + rand_data.nonzeros[i] = data.nonzeros[i]; + } + } + batch_data.emplace_back(rand_data); } return gko::batch::read( exec, batch_data, std::forward(args)...); diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 3efb93e664b..4e29ab32886 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -60,13 +60,13 @@ __dpct_inline__ void initialize( x_shared_entry[iz] = x_global_entry[iz]; r_shared_entry[iz] = b_global_entry[iz]; } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // r = b - A*x advanced_apply_kernel(static_cast(-1.0), mat_global_entry, x_shared_entry, static_cast(1.0), r_shared_entry, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); if constexpr (sg_kernel_all) { if (sg_id == 0) { @@ -80,7 +80,7 @@ __dpct_inline__ void initialize( single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); for (int iz = tid; iz < num_rows; iz += group_size) { @@ -125,11 +125,11 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, v_shared_entry, alpha, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); if (tid == 0) { alpha = rho_new / alpha; } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry, v_shared_entry, alpha, item_ct1); @@ -173,11 +173,11 @@ __dpct_inline__ void compute_omega(const int num_rows, single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry, temp, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); if (tid == 0) { omega /= temp; } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry, omega, item_ct1); @@ -350,7 +350,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, x_global_entry, rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // stopping criterion object StopType stop(tol, norms_rhs_sh); @@ -368,7 +368,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, rho_new_sh[0], item_ct1); } - 
item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], item_ct1); @@ -378,24 +378,24 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // p = r + beta*(p - omega * v) update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], omega_sh[0], r_sh, v_sh, p_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // p_hat = precond * p prec_shared.apply(num_rows, p_sh, p_hat_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // v = A * p_hat simple_apply_kernel(mat_global_entry, p_hat_sh, v_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // alpha = rho_new / < r_hat , v> compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // s = r - alpha*v update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // an estimate of residual norms if constexpr (sg_kernel_all) { @@ -403,7 +403,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], item_ct1); } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1); } @@ -416,22 +416,22 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // s_hat = precond * s prec_shared.apply(num_rows, s_sh, s_hat_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // t = A * s_hat simple_apply_kernel(mat_global_entry, s_hat_sh, t_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // omega = / compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); // x = x + alpha*p_hat + omega *s_hat // r = s - omega * t update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], omega_sh[0], s_sh, t_sh, x_sh, r_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); if constexpr (sg_kernel_all) { if (sg_id == 0) @@ -440,7 +440,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, if (tid == group_size - 1) { rho_old_sh[0] = rho_new_sh[0]; } - item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1); rho_old_sh[0] = rho_new_sh[0]; @@ -451,5 +451,5 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // copy x back to global memory copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); - 
item_ct1.barrier(sycl::access::fence_space::global_and_local); + item_ct1.barrier(sycl::access::fence_space::local_space); } diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp index f99e7a469d0..4bec19a165f 100644 --- a/test/solver/batch_bicgstab_kernels.cpp +++ b/test/solver/batch_bicgstab_kernels.cpp @@ -222,7 +222,7 @@ TEST_F(BatchBicgstab, CanSolveLargeBatchSizeHpdSystem) linear_system.host_rhs_norm->get_const_values()[i], tol); EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); - ASSERT_LE(comp_res_norm, tol); + ASSERT_LE(comp_res_norm, tol * 10); } } @@ -257,6 +257,6 @@ TEST_F(BatchBicgstab, CanSolveLargeMatrixSizeHpdSystem) linear_system.host_rhs_norm->get_const_values()[i], tol); EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); - ASSERT_LE(comp_res_norm, tol); + ASSERT_LE(comp_res_norm, tol * 10); } } From 6729f6842bd3c388e38bde987a5b607cc7d722d2 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sat, 4 Nov 2023 09:09:39 +0100 Subject: [PATCH 23/28] Fix some include issues --- cuda/base/kernel_config.hpp | 1 + reference/log/batch_logger.hpp | 8 +++----- reference/preconditioner/batch_identity.hpp | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cuda/base/kernel_config.hpp b/cuda/base/kernel_config.hpp index 1fbc0d6e4d8..b8b4f621f06 100644 --- a/cuda/base/kernel_config.hpp +++ b/cuda/base/kernel_config.hpp @@ -35,6 +35,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include namespace gko { diff --git a/reference/log/batch_logger.hpp b/reference/log/batch_logger.hpp index a70af0af51c..2598c23766f 100644 --- a/reference/log/batch_logger.hpp +++ b/reference/log/batch_logger.hpp @@ -51,8 +51,6 @@ namespace batch_log { template class SimpleFinalLogger final { public: - using real_type = remove_complex; - /** * Constructor * @@ -61,7 +59,7 @@ class SimpleFinalLogger final { * @param batch_iters final iteration counts for each * linear system in the batch. */ - SimpleFinalLogger(real_type* const batch_residuals, int* const batch_iters) + SimpleFinalLogger(RealType* const batch_residuals, int* const batch_iters) : final_residuals_{batch_residuals}, final_iters_{batch_iters} {} @@ -73,14 +71,14 @@ class SimpleFinalLogger final { * @param res_norm Norm of final residual norm */ void log_iteration(const size_type batch_idx, const int iter, - const real_type res_norm) + const RealType res_norm) { final_iters_[batch_idx] = iter; final_residuals_[batch_idx] = res_norm; } private: - real_type* const final_residuals_; + RealType* const final_residuals_; int* const final_iters_; }; diff --git a/reference/preconditioner/batch_identity.hpp b/reference/preconditioner/batch_identity.hpp index b0bf869c6be..6d6d462e660 100644 --- a/reference/preconditioner/batch_identity.hpp +++ b/reference/preconditioner/batch_identity.hpp @@ -34,6 +34,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
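// Note on the hunk below: the added include pulls in the batch_struct
// definitions used by the reference batch Identity preconditioner. For
// orientation, the apply of an identity preconditioner amounts to a plain
// copy of the input vector; a minimal free-standing sketch (names here are
// illustrative assumptions, the actual class body is outside this hunk):
//
//     template <typename ValueType>
//     void identity_apply(const int num_rows, const ValueType* const r,
//                         ValueType* const z)
//     {
//         for (int row = 0; row < num_rows; ++row) {
//             z[row] = r[row];  // z = I * r, i.e. no preconditioning
//         }
//     }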
#define GKO_REFERENCE_PRECONDITIONER_BATCH_IDENTITY_HPP_ +#include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" From eebc06a201938e4d69a97d5992c8936b941f1efd Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sat, 4 Nov 2023 18:37:26 +0100 Subject: [PATCH 24/28] Review updates Co-authored-by: Yu-Hsiang Tsai --- cuda/base/kernel_config.hpp | 16 ++++------------ cuda/solver/batch_bicgstab_kernels.cu | 6 ++---- dpcpp/base/batch_multi_vector_kernels.hpp.inc | 3 ++- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 2 -- hip/solver/batch_bicgstab_kernels.hip.cpp | 4 +--- 5 files changed, 9 insertions(+), 22 deletions(-) diff --git a/cuda/base/kernel_config.hpp b/cuda/base/kernel_config.hpp index b8b4f621f06..a4aecea1d55 100644 --- a/cuda/base/kernel_config.hpp +++ b/cuda/base/kernel_config.hpp @@ -35,6 +35,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include + + #include @@ -68,18 +70,8 @@ class shared_memory_config_guard { ~shared_memory_config_guard() { - auto error_code = cudaDeviceSetSharedMemConfig(original_config_); - if (error_code != cudaSuccess) { -#if GKO_VERBOSE_LEVEL >= 1 - std::cerr << "Unrecoverable CUDA error while resetting the " - "shared memory config to " - << original_config_ << " in " << __func__ << ": " - << cudaGetErrorName(error_code) << ": " - << cudaGetErrorString(error_code) << std::endl - << "Exiting program" << std::endl; -#endif // GKO_VERBOSE_LEVEL >= 1 - std::exit(error_code); - } + // No need to exit or throw if we can't set the value back. + cudaDeviceSetSharedMemConfig(original_config_); } private: diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 16df7e7e55e..1d80f206c1b 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -205,11 +205,11 @@ public: // Template parameters launch_apply_kernel - if (sconf.prec_shared) + if (sconf.prec_shared) { launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, block_size, shared_size); - else { + } else { switch (sconf.n_shared) { case 0: launch_apply_kernel( @@ -265,8 +265,6 @@ public: GKO_NOT_IMPLEMENTED; } } - - exec_->synchronize(); } private: diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index 1fb5684871d..be9d02aa88d 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -163,8 +163,9 @@ __dpct_inline__ void single_rhs_compute_norm2_sg( using real_type = typename gko::remove_complex; real_type val = zero(); - for (int r = subgroup.get_local_id(); r < num_rows; r += subgroup_size) + for (int r = subgroup.get_local_id(); r < num_rows; r += subgroup_size) { val += squared_norm(x[r]); + } val = ::gko::kernels::dpcpp::reduce( subg, val, [](real_type a, real_type b) { return a + b; }); diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 4e29ab32886..4be5040d4ea 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -260,8 +260,6 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, tile_real_t& reals = *sycl::ext::oneapi::group_local_memory_for_overwrite( group); - // ValueType values[5]; - // real_type reals[2]; rho_old_sh = &values[0]; rho_new_sh = &values[1]; alpha_sh = &values[2]; diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 
fbd6543574f..217d314a5c9 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -89,7 +89,7 @@ int get_num_threads_per_block(std::shared_ptr exec, const int min_block_size = 2 * warp_sz; const int device_max_threads = ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz; - // This value has been taken from RocM docs. This is the number of registers + // This value has been taken from ROCm docs. This is the number of registers // that maximizes the occupancy on an AMD GPU (MI200). HIP does not have an // API to query the number of registers a function uses. const int num_regs_used_per_thread = 64; @@ -238,8 +238,6 @@ class kernel_caller { GKO_NOT_IMPLEMENTED; } } - - exec_->synchronize(); } private: From 498512cb8932b12b55724ab41150c22b0d69586f Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sat, 4 Nov 2023 22:03:00 +0100 Subject: [PATCH 25/28] use fence_space::global_and_local --- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 42 ++++++++++----------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 4be5040d4ea..636227973a8 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -60,13 +60,13 @@ __dpct_inline__ void initialize( x_shared_entry[iz] = x_global_entry[iz]; r_shared_entry[iz] = b_global_entry[iz]; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // r = b - A*x advanced_apply_kernel(static_cast(-1.0), mat_global_entry, x_shared_entry, static_cast(1.0), r_shared_entry, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); if constexpr (sg_kernel_all) { if (sg_id == 0) { @@ -80,7 +80,7 @@ __dpct_inline__ void initialize( single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); for (int iz = tid; iz < num_rows; iz += group_size) { @@ -125,11 +125,11 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, v_shared_entry, alpha, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { alpha = rho_new / alpha; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry, v_shared_entry, alpha, item_ct1); @@ -173,11 +173,11 @@ __dpct_inline__ void compute_omega(const int num_rows, single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry, temp, item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { omega /= temp; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry, omega, item_ct1); @@ -348,7 +348,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, x_global_entry, rho_old_sh[0], omega_sh[0], 
alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // stopping criterion object StopType stop(tol, norms_rhs_sh); @@ -366,7 +366,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, rho_new_sh[0], item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], item_ct1); @@ -376,24 +376,24 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // p = r + beta*(p - omega * v) update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], omega_sh[0], r_sh, v_sh, p_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // p_hat = precond * p prec_shared.apply(num_rows, p_sh, p_hat_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // v = A * p_hat simple_apply_kernel(mat_global_entry, p_hat_sh, v_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // alpha = rho_new / < r_hat , v> compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // s = r - alpha*v update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // an estimate of residual norms if constexpr (sg_kernel_all) { @@ -401,7 +401,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], item_ct1); } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1); } @@ -414,22 +414,22 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // s_hat = precond * s prec_shared.apply(num_rows, s_sh, s_hat_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // t = A * s_hat simple_apply_kernel(mat_global_entry, s_hat_sh, t_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // omega = / compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); // x = x + alpha*p_hat + omega *s_hat // r = s - omega * t update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], omega_sh[0], s_sh, t_sh, x_sh, r_sh, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); if constexpr (sg_kernel_all) { if (sg_id == 0) @@ -438,7 +438,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, if (tid == group_size - 1) { rho_old_sh[0] = rho_new_sh[0]; } - 
item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } else { single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1); rho_old_sh[0] = rho_new_sh[0]; @@ -449,5 +449,5 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // copy x back to global memory copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::global_and_local); } From 1bc6d8382755c28c08b081192aac3756f6af31f4 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sun, 5 Nov 2023 15:30:06 +0100 Subject: [PATCH 26/28] Use updated deferred factory macros. --- .../ginkgo/core/solver/batch_solver_base.hpp | 56 ++++--------------- 1 file changed, 11 insertions(+), 45 deletions(-) diff --git a/include/ginkgo/core/solver/batch_solver_base.hpp b/include/ginkgo/core/solver/batch_solver_base.hpp index 8cc5c67837a..8f534753bf8 100644 --- a/include/ginkgo/core/solver/batch_solver_base.hpp +++ b/include/ginkgo/core/solver/batch_solver_base.hpp @@ -182,25 +182,13 @@ class BatchSolver { * excluding the parameters available in iterative_solver_factory_parameters. * @see GKO_CREATE_FACTORY_PARAMETERS */ -struct preconditioned_iterative_solver_factory_parameters { - /** - * The preconditioner to be used by the iterative solver. By default, no - * preconditioner is used. - */ - std::shared_ptr preconditioner{nullptr}; - - /** - * Already generated preconditioner. If one is provided, the factory - * `preconditioner` will be ignored. - */ - std::shared_ptr generated_preconditioner{nullptr}; -}; +struct preconditioned_iterative_solver_factory_parameters {}; template struct enable_preconditioned_iterative_solver_factory_parameters - : enable_parameters_type, - preconditioned_iterative_solver_factory_parameters { + : enable_parameters_type { + using parameters_type = Parameters; /** * Default maximum number iterations allowed. * @@ -225,40 +213,18 @@ struct enable_preconditioned_iterative_solver_factory_parameters tolerance_type, ::gko::batch::stop::tolerance_type::absolute); /** - * Provides a preconditioner factory to be used by the iterative solver in a - * fluent interface. - * @see preconditioned_iterative_solver_factory_parameters::preconditioner + * The preconditioner to be used by the iterative solver. By default, no + * preconditioner is used. */ - Parameters& with_preconditioner( - deferred_factory_parameter preconditioner) - { - this->preconditioner_generator = std::move(preconditioner); - this->deferred_factories["preconditioner"] = [](const auto& exec, - auto& params) { - if (!params.preconditioner_generator.is_empty()) { - params.preconditioner = - params.preconditioner_generator.on(exec); - } - }; - return *self(); - } + std::shared_ptr GKO_DEFERRED_FACTORY_PARAMETER( + preconditioner); /** - * Provides a concrete preconditioner to be used by the iterative solver in - * a fluent interface. - * @see preconditioned_iterative_solver_factory_parameters::preconditioner + * Already generated preconditioner. If one is provided, the factory + * `preconditioner` will be ignored. 
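 *
 * A minimal usage sketch of how the two parameters interact, assuming a
 * concrete solver type `Solver` derived from this base (illustrative
 * only, not part of this header):
 *
 *     auto uses_factory =
 *         Solver::build().with_preconditioner(prec_factory).on(exec);
 *     auto uses_generated =
 *         Solver::build()
 *             .with_generated_preconditioner(prec_op)
 *             .on(exec);  // any `preconditioner` factory is ignored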
*/ - Parameters& with_generated_preconditioner( - std::shared_ptr generated_preconditioner) - { - this->generated_preconditioner = std::move(generated_preconditioner); - return *self(); - } - -private: - GKO_ENABLE_SELF(Parameters); - - deferred_factory_parameter preconditioner_generator; + std::shared_ptr GKO_FACTORY_PARAMETER_SCALAR( + generated_preconditioner, nullptr); }; From 79e68b36550c5518e45863c36372e742ad0d1c23 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sun, 5 Nov 2023 16:59:31 +0100 Subject: [PATCH 27/28] Review updates Co-authored-by: Yu-Hsiang Tsai --- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 39 ++--- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 140 +++++++----------- .../ginkgo/core/solver/batch_solver_base.hpp | 9 -- include/ginkgo/core/solver/solver_base.hpp | 3 - 4 files changed, 70 insertions(+), 121 deletions(-) diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 9da926c7c58..8f0a334e6ac 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -94,8 +94,8 @@ class KernelCaller { {} template + const int n_shared_total, typename PrecType, typename LogType, + typename BatchMatrixType> __dpct_inline__ void launch_apply_kernel( const gko::kernels::batch_bicgstab::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType mat, @@ -118,9 +118,10 @@ class KernelCaller { slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - subgroup_size)]] [[intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [ + [intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); @@ -130,7 +131,7 @@ class KernelCaller { ValueType* const x_global_entry = gko::batch::multi_vector::batch_item_ptr( x_values, 1, num_rows, batch_id); - apply_kernel( + apply_kernel( sconf, max_iters, res_tol, logger, prec, mat_global_entry, b_global_entry, x_global_entry, num_rows, mat.get_single_item_num_nnz(), @@ -197,67 +198,67 @@ class KernelCaller { // launch_apply_kernel if (num_rows <= 32 && n_shared_total == 10) { - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); } else if (num_rows <= 256 && n_shared_total == 10) { - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); } else { switch (n_shared_total) { case 0: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 1: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 2: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 3: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 4: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 5: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, 
x.values, workspace_data, group_size, shared_size); break; case 6: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 7: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 8: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 9: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; case 10: - launch_apply_kernel( + launch_apply_kernel( sconf, logger, prec, mat, b.values, x.values, workspace_data, group_size, shared_size); break; diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 636227973a8..03f8ea31165 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -30,8 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *************************************************************/ -template +template __dpct_inline__ void initialize( const int num_rows, const BatchMatrixType_entry& mat_global_entry, const ValueType* const b_global_entry, @@ -68,17 +67,12 @@ __dpct_inline__ void initialize( r_shared_entry, item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); - if constexpr (sg_kernel_all) { - if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, - item_ct1); - } else if (sg_id == 1) { - single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, - item_ct1); - } - } else { - single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1); - single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1); + if (sg_id == 0) { + single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, + item_ct1); + } else if (sg_id == 1) { + single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, + item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -111,7 +105,7 @@ __dpct_inline__ void update_p(const int num_rows, const ValueType& rho_new, } -template +template __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, const ValueType* const r_hat_shared_entry, const ValueType* const v_shared_entry, @@ -120,23 +114,15 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, auto sg = item_ct1.get_sub_group(); const auto sg_id = sg.get_group_id(); const auto tid = item_ct1.get_local_linear_id(); - if constexpr (sg_kernel_all) { - if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, - v_shared_entry, alpha, item_ct1); - } - item_ct1.barrier(sycl::access::fence_space::global_and_local); - if (tid == 0) { - alpha = rho_new / alpha; - } - item_ct1.barrier(sycl::access::fence_space::global_and_local); - } else { - single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry, - v_shared_entry, alpha, item_ct1); - if (tid == 0) { - alpha = rho_new / alpha; - } + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + if (tid == 0) { + alpha = rho_new / alpha; + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); } @@ -155,7 +141,7 
@@ __dpct_inline__ void update_s(const int num_rows, } -template +template __dpct_inline__ void compute_omega(const int num_rows, const ValueType* const t_shared_entry, const ValueType* const s_shared_entry, @@ -165,28 +151,18 @@ __dpct_inline__ void compute_omega(const int num_rows, auto sg = item_ct1.get_sub_group(); const auto sg_id = sg.get_group_id(); const auto tid = item_ct1.get_local_linear_id(); - if constexpr (sg_kernel_all) { - if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, - s_shared_entry, omega, item_ct1); - } else if (sg_id == 1) { - single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, - t_shared_entry, temp, item_ct1); - } - item_ct1.barrier(sycl::access::fence_space::global_and_local); - if (tid == 0) { - omega /= temp; - } - item_ct1.barrier(sycl::access::fence_space::global_and_local); - } else { - single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry, - omega, item_ct1); - single_rhs_compute_conj_dot(num_rows, t_shared_entry, t_shared_entry, - temp, item_ct1); - if (tid == 0) { - omega /= temp; - } + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry, + omega, item_ct1); + } else if (sg_id == 1) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry, + temp, item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + if (tid == 0) { + omega /= temp; } + item_ct1.barrier(sycl::access::fence_space::global_and_local); } @@ -220,9 +196,8 @@ __dpct_inline__ void update_x_middle(const int num_rows, const ValueType& alpha, } -template +template void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, const int max_iter, const gko::remove_complex tol, LogType logger, PrecType prec_shared, @@ -344,10 +319,10 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // p = 0 // p_hat = 0 // v = 0 - initialize(num_rows, mat_global_entry, b_global_entry, - x_global_entry, rho_old_sh[0], omega_sh[0], - alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh, - v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1); + initialize(num_rows, mat_global_entry, b_global_entry, x_global_entry, + rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, r_hat_sh, + p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0], + item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); // stopping criterion object @@ -361,16 +336,11 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, } // rho_new = < r_hat , r > = (r_hat)' * (r) - if constexpr (sg_kernel_all) { - if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, - rho_new_sh[0], item_ct1); - } - item_ct1.barrier(sycl::access::fence_space::global_and_local); - } else { - single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], - item_ct1); + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, + rho_new_sh[0], item_ct1); } + item_ct1.barrier(sycl::access::fence_space::global_and_local); // beta = (rho_new / rho_old)*(alpha / omega) // p = r + beta*(p - omega * v) @@ -387,8 +357,8 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, item_ct1.barrier(sycl::access::fence_space::global_and_local); // alpha = rho_new / < r_hat , v> - compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, - alpha_sh[0], item_ct1); + compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0], + item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); // s = r - 
alpha*v @@ -396,15 +366,11 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, item_ct1.barrier(sycl::access::fence_space::global_and_local); // an estimate of residual norms - if constexpr (sg_kernel_all) { - if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], - item_ct1); - } - item_ct1.barrier(sycl::access::fence_space::global_and_local); - } else { - single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1); + if (sg_id == 0) { + single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], + item_ct1); } + item_ct1.barrier(sycl::access::fence_space::global_and_local); if (stop.check_converged(norms_res_sh)) { update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh, item_ct1); @@ -421,8 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, item_ct1.barrier(sycl::access::fence_space::global_and_local); // omega = / - compute_omega(num_rows, t_sh, s_sh, temp_sh[0], - omega_sh[0], item_ct1); + compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); // x = x + alpha*p_hat + omega *s_hat @@ -431,18 +396,13 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, s_sh, t_sh, x_sh, r_sh, item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); - if constexpr (sg_kernel_all) { - if (sg_id == 0) - single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], - item_ct1); - if (tid == group_size - 1) { - rho_old_sh[0] = rho_new_sh[0]; - } - item_ct1.barrier(sycl::access::fence_space::global_and_local); - } else { - single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1); + if (sg_id == 0) + single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], + item_ct1); + if (tid == group_size - 1) { rho_old_sh[0] = rho_new_sh[0]; } + item_ct1.barrier(sycl::access::fence_space::global_and_local); } logger.log_iteration(batch_id, iter, norms_res_sh[0]); diff --git a/include/ginkgo/core/solver/batch_solver_base.hpp b/include/ginkgo/core/solver/batch_solver_base.hpp index 8f534753bf8..cd4ae8d1590 100644 --- a/include/ginkgo/core/solver/batch_solver_base.hpp +++ b/include/ginkgo/core/solver/batch_solver_base.hpp @@ -177,18 +177,9 @@ class BatchSolver { }; -/** - * The parameter type shared between all preconditioned iterative solvers, - * excluding the parameters available in iterative_solver_factory_parameters. - * @see GKO_CREATE_FACTORY_PARAMETERS - */ -struct preconditioned_iterative_solver_factory_parameters {}; - - template struct enable_preconditioned_iterative_solver_factory_parameters : enable_parameters_type { - using parameters_type = Parameters; /** * Default maximum number iterations allowed. * diff --git a/include/ginkgo/core/solver/solver_base.hpp b/include/ginkgo/core/solver/solver_base.hpp index cd0043c7b44..070cc4e6b4a 100644 --- a/include/ginkgo/core/solver/solver_base.hpp +++ b/include/ginkgo/core/solver/solver_base.hpp @@ -856,7 +856,6 @@ class EnablePreconditionedIterativeSolver template struct enable_iterative_solver_factory_parameters : enable_parameters_type { - using parameters_type = Parameters; /** * Stopping criteria to be used by the solver. */ @@ -868,8 +867,6 @@ struct enable_iterative_solver_factory_parameters template struct enable_preconditioned_iterative_solver_factory_parameters : enable_iterative_solver_factory_parameters { - using parameters_type = Parameters; - /** * The preconditioner to be used by the iterative solver. 
By default, no * preconditioner is used. From a1b84d4b6b2ce80689d7340d1f1ec2880d9796e6 Mon Sep 17 00:00:00 2001 From: ginkgo-bot Date: Sun, 5 Nov 2023 16:20:56 +0000 Subject: [PATCH 28/28] Format files Co-authored-by: Pratik Nayak --- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 8f0a334e6ac..9e353734f36 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -118,10 +118,9 @@ class KernelCaller { slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), [= - ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - subgroup_size)]] [ - [intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [[intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id);
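For orientation, the factory parameters touched in patches 26 and 27 surface in client code roughly as sketched below. This is a minimal sketch only: the solver type gko::batch::solver::Bicgstab and the with_* setters generated from the factory-parameter macros are assumptions inferred from the hunks above, not text from this series.

    #include <ginkgo/ginkgo.hpp>

    int main()
    {
        // The reference executor keeps the sketch self-contained; the CUDA,
        // HIP and DPC++ kernels in this series are selected via the executor.
        auto exec = gko::ReferenceExecutor::create();
        using Solver = gko::batch::solver::Bicgstab<double>;

        // `mat`, `b` and `x` would be batch matrix/multi-vector objects
        // created elsewhere (hypothetical here).
        auto factory =
            Solver::build()
                .with_max_iterations(100)
                .with_tolerance(1e-6)
                // relative: converged once ||r|| <= ||b|| * tol, cf. the
                // batch_stop_enum.hpp documentation fixed above
                .with_tolerance_type(
                    gko::batch::stop::tolerance_type::relative)
                .on(exec);
        // auto solver = factory->generate(mat);
        // solver->apply(b, x);
        return 0;
    }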