From fb50eaffd7823138148c5b8700ba72ecd118a9a5 Mon Sep 17 00:00:00 2001
From: Pratik Nayak
Date: Mon, 30 Oct 2023 12:13:02 +0100
Subject: [PATCH] Review updates

Co-authored-by: Yu-Hsiang Tsai
Co-authored-by: Marcel Koch
---
 .../base/batch_multi_vector_kernels.hpp.inc   |   2 +-
 .../preconditioner/batch_identity.hpp.inc     |  13 +-
 .../solver/batch_bicgstab_kernels.hpp.inc     |  26 ++--
 core/solver/batch_bicgstab_kernels.hpp        |  18 +--
 core/test/utils/batch_helpers.hpp             |  18 +--
 cuda/matrix/batch_struct.hpp                  |   3 +-
 cuda/solver/batch_bicgstab_kernels.cu         |  17 +--
 dpcpp/base/batch_multi_vector_kernels.dp.cpp  | 137 +++++++++---------
 dpcpp/base/batch_multi_vector_kernels.hpp.inc |  13 +-
 dpcpp/matrix/batch_struct.hpp                 |   2 +-
 dpcpp/preconditioner/batch_identity.hpp.inc   |  12 +-
 dpcpp/solver/batch_bicgstab_kernels.dp.cpp    |  45 +++---
 dpcpp/solver/batch_bicgstab_kernels.hpp.inc   | 112 +++++++-------
 hip/matrix/batch_struct.hip.hpp               |   3 +-
 hip/solver/batch_bicgstab_kernels.hip.cpp     |  57 ++++----
 reference/matrix/batch_struct.hpp             |   2 +-
 .../test/solver/batch_bicgstab_kernels.cpp    |  25 ++--
 test/solver/batch_bicgstab_kernels.cpp        |  24 +--
 18 files changed, 254 insertions(+), 275 deletions(-)

diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
index 72d58ecf5b3..1e0cb3bbcff 100644
--- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
+++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
@@ -104,7 +104,7 @@ __global__ __launch_bounds__(


 template
-__device__ __forceinline__ void single_rhs_compute_dot(Group subgroup,
+__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup,
                                                        const int num_rows,
                                                        const ValueType* x,
                                                        const ValueType* y,
diff --git a/common/cuda_hip/preconditioner/batch_identity.hpp.inc b/common/cuda_hip/preconditioner/batch_identity.hpp.inc
index 1b1fb7b5482..923ed4ce946 100644
--- a/common/cuda_hip/preconditioner/batch_identity.hpp.inc
+++ b/common/cuda_hip/preconditioner/batch_identity.hpp.inc
@@ -45,16 +45,9 @@ public:
         return 0;
     }

-    __device__ __forceinline__ void generate(
-        size_type,
-        const gko::batch::matrix::ell::batch_item&,
-        ValueType*)
-    {}
-
-    __device__ __forceinline__ void generate(
-        size_type,
-        const gko::batch::matrix::dense::batch_item&,
-        ValueType*)
+    template <typename batch_item_type>
+    __device__ __forceinline__ void generate(size_type, const batch_item_type&,
+                                             ValueType*)
     {}

     __device__ __forceinline__ void apply(const int num_rows,
diff --git a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
index 0f666f205e8..faee2e069a7 100644
--- a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
+++ b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
@@ -38,7 +38,8 @@ __device__ __forceinline__ void initialize(
     const ValueType* const x_global_entry, ValueType& rho_old,
     ValueType& omega, ValueType& alpha, ValueType* const x_shared_entry,
     ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
-    ValueType* const p_shared_entry, ValueType* const v_shared_entry,
+    ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry,
+    ValueType* const v_shared_entry,
     typename gko::remove_complex& rhs_norm,
     typename gko::remove_complex& res_norm)
 {
@@ -70,6 +71,7 @@ __device__ __forceinline__ void initialize(
     for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
         r_hat_shared_entry[iz] = r_shared_entry[iz];
         p_shared_entry[iz] = zero();
+        p_hat_shared_entry[iz] = zero();
         v_shared_entry[iz] = zero();
     }
 }
@@ -82,8 +84,8 @@ __device__ __forceinline__ void update_p(
     const ValueType* const r_shared_entry,
     const ValueType* const v_shared_entry, ValueType* const p_shared_entry)
 {
+    const ValueType beta = (rho_new / rho_old) * (alpha / omega);
     for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
-        const ValueType beta = (rho_new / rho_old) * (alpha / omega);
         p_shared_entry[r] =
             r_shared_entry[r] +
             beta * (p_shared_entry[r] - omega * v_shared_entry[r]);
@@ -97,8 +99,8 @@ __device__ __forceinline__ void compute_alpha(
     const ValueType* const v_shared_entry, ValueType& alpha)
 {
     if (threadIdx.x / config::warp_size == 0) {
-        single_rhs_compute_dot(subgroup, num_rows, r_hat_shared_entry,
-                               v_shared_entry, alpha);
+        single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
+                                    v_shared_entry, alpha);
     }
     __syncthreads();
     if (threadIdx.x == 0) {
@@ -126,11 +128,11 @@ __device__ __forceinline__ void compute_omega(
     const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
 {
     if (threadIdx.x / config::warp_size == 0) {
-        single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
-                               s_shared_entry, omega);
+        single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
+                                    s_shared_entry, omega);
     } else if (threadIdx.x / config::warp_size == 1) {
-        single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
-                               t_shared_entry, temp);
+        single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
+                                    t_shared_entry, temp);
     }

     __syncthreads();
@@ -278,10 +280,12 @@ __global__ void apply_kernel(
     // compute residual norms
     // r_hat = r
     // p = 0
+    // p_hat = 0
     // v = 0
     initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr,
                rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh,
-               r_hat_sh, p_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0]);
+               r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0],
+               norms_res_sh[0]);
     __syncthreads();

     // stopping criterion object
@@ -296,8 +300,8 @@ __global__ void apply_kernel(

         // rho_new =  < r_hat , r > = (r_hat)' * (r)
         if (threadIdx.x / config::warp_size == 0) {
-            single_rhs_compute_dot(subgroup, num_rows, r_hat_sh, r_sh,
-                                   rho_new_sh[0]);
+            single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh,
+                                        rho_new_sh[0]);
         }
         __syncthreads();
diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp
index b3a5faf0a49..f43a7f2ddd5 100644
--- a/core/solver/batch_bicgstab_kernels.hpp
+++ b/core/solver/batch_bicgstab_kernels.hpp
@@ -115,8 +115,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
     }
     // align global memory chunks
     sconf.gmem_stride_bytes =
-        gmem_stride > 0 ? ((gmem_stride - 1) / align_bytes + 1) * align_bytes
-                        : 0;
+        gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0;
 }
@@ -143,8 +142,8 @@ void set_gmem_stride_bytes(storage_config& sconf,
  * - rhs_norms
  * - res_norms
  *
- * @param shared_mem_per_blk  The amount of shared memory per block to use for
- *   keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
+ * @param available_shared_mem  The amount of shared memory per block to use
+ *   for keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
 *   should be prioritized, the cache configuration must be updated separately
 *   and the needed space should be subtracted before passing to this
 *   function.
@@ -154,7 +153,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
 *
 * @return  A struct containing allocation information specific to Bicgstab.
 */
 template <typename Prectype, typename ValueType>
-storage_config compute_shared_storage(const int shared_mem_per_blk,
+storage_config compute_shared_storage(const int available_shared_mem,
                                       const int num_rows, const int num_nz,
                                       const int num_rhs)
 {
@@ -163,10 +162,11 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
     const int num_main_vecs = 9;
     const int prec_storage =
         Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
-    int rem_shared = shared_mem_per_blk;
-    // Set default values. All vecs are in global.
+    int rem_shared = available_shared_mem;
+    // Set default values. Initially all vecs are in global memory.
+    // {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
     storage_config sconf{false, 0, num_main_vecs, 0, num_rows};
-    // If available shared mem, is zero, set all vecs to global.
+    // If available shared mem is zero, set all vecs to global.
     if (rem_shared <= 0) {
         set_gmem_stride_bytes(sconf, vec_size, prec_storage);
         return sconf;
     }
@@ -177,13 +177,13 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
     const int num_vecs_shared = min(initial_vecs_available, num_main_vecs);
     sconf.n_shared += num_vecs_shared;
     sconf.n_global -= num_vecs_shared;
+    rem_shared -= num_vecs_shared * vec_size;
     // Set the storage configuration with preconditioner workspace in global if
     // there are any vectors in global memory.
     if (sconf.n_global > 0) {
         set_gmem_stride_bytes(sconf, vec_size, prec_storage);
         return sconf;
     }
-    rem_shared -= num_vecs_shared * vec_size;
     // If more shared memory space is available and preconditioner workspace is
     // needed, enable preconditioner workspace to use shared memory.
     if (rem_shared >= prec_storage && prec_storage > 0) {
diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp
index 5dfca81ffff..d5bbd77fa08 100644
--- a/core/test/utils/batch_helpers.hpp
+++ b/core/test/utils/batch_helpers.hpp
@@ -224,7 +224,7 @@ struct LinearSystem {

     std::shared_ptr matrix;
     std::shared_ptr rhs;
-    std::shared_ptr rhs_norm;
+    std::shared_ptr host_rhs_norm;
     std::shared_ptr exact_sol;
 };
@@ -250,8 +250,8 @@ LinearSystem generate_batch_linear_system(
     // A * x_{exact} = b
     sys.matrix->apply(sys.exact_sol, sys.rhs);
     const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs));
-    sys.rhs_norm = real_vec::create(exec->get_master(), norm_dim);
-    sys.rhs->compute_norm2(sys.rhs_norm.get());
+    sys.host_rhs_norm = real_vec::create(exec->get_master(), norm_dim);
+    sys.rhs->compute_norm2(sys.host_rhs_norm.get());
     return sys;
 }
@@ -273,13 +273,13 @@ compute_residual_norms(
     const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs));

     auto residual_vec = b->clone();
-    auto res_norms = real_vec::create(exec->get_master(), norm_dim);
+    auto res_norm = real_vec::create(exec->get_master(), norm_dim);
     auto alpha = gko::batch::initialize(num_batch_items, {-1.0}, exec);
     auto beta = gko::batch::initialize(num_batch_items, {1.0}, exec);
     mtx->apply(alpha, x, beta, residual_vec);
-    residual_vec->compute_norm2(res_norms);
-    return res_norms;
+    residual_vec->compute_norm2(res_norm);
+    return res_norm;
 }
@@ -289,7 +289,7 @@ struct Result {
     using real_vec = batch::MultiVector>;

     std::shared_ptr x;
-    std::shared_ptr res_norm;
+    std::shared_ptr host_res_norm;
 };
@@ -323,7 +323,7 @@ Result solve_linear_system(
     result.x->fill(zero());

     solver->apply(sys.rhs, result.x);
-    result.res_norm =
+    result.host_res_norm =
         compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get());

     return std::move(result);
@@ -369,7 +369,7 @@ ResultWithLogData solve_linear_system(
     result.log_data->iter_counts = log_data->iter_counts;
     result.log_data->res_norms = log_data->res_norms;

-    result.res_norm =
+    result.host_res_norm =
         compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get());

     return std::move(result);
diff --git a/cuda/matrix/batch_struct.hpp b/cuda/matrix/batch_struct.hpp
index 4a2a1835961..55a30c043e3 100644
--- a/cuda/matrix/batch_struct.hpp
+++ b/cuda/matrix/batch_struct.hpp
@@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op)
 * Generates an immutable uniform batch struct from a batch of ell matrices.
 */
 template
-inline batch::matrix::ell::uniform_batch, IndexType>
+inline batch::matrix::ell::uniform_batch,
+                                          const IndexType>
 get_batch_struct(const batch::matrix::Ell* const op)
 {
     return {as_cuda_type(op->get_const_values()),
diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu
index 9ecb27aecf2..1c26a1754ba 100644
--- a/cuda/solver/batch_bicgstab_kernels.cu
+++ b/cuda/solver/batch_bicgstab_kernels.cu
@@ -101,13 +101,10 @@ int get_num_threads_per_block(std::shared_ptr exec,
     cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
                           exec->get_device_id());
     const int max_threads_regs =
-        ((max_regs_blk /
-          static_cast((static_cast(num_regs_used)))) /
-         warp_sz) *
-        warp_sz;
+        ((max_regs_blk / static_cast(num_regs_used)) / warp_sz) * warp_sz;
     int max_threads = std::min(max_threads_regs, device_max_threads);
     max_threads = max_threads <= 1024 ? max_threads : 1024;
-    return std::min(num_warps * warp_sz, max_threads);
+    return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
 }
@@ -136,12 +133,12 @@ using settings = gko::kernels::batch_bicgstab::settings;

 template
-class KernelCaller {
+class kernel_caller {
 public:
     using value_type = CuValueType;

-    KernelCaller(std::shared_ptr exec,
-                 const settings> settings)
+    kernel_caller(std::shared_ptr exec,
+                  const settings> settings)
         : exec_{std::move(exec)}, settings_{settings}
     {}
@@ -263,6 +260,8 @@ public:
                 sconf, logger, prec, mat, b.values, x.values, workspace_data,
                 block_size, shared_size);
             break;
+        default:
+            GKO_NOT_IMPLEMENTED;
         }
     }
@@ -286,7 +285,7 @@ void apply(std::shared_ptr exec,
 {
     using cu_value_type = cuda_type;
     auto dispatcher = batch::solver::create_dispatcher(
-        KernelCaller(exec, settings), settings, mat, precon);
+        kernel_caller(exec, settings), settings, mat, precon);
     dispatcher.apply(b, x, logdata);
 }
diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp
index 3068b654b75..c9809696889 100644
--- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp
+++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp
@@ -87,7 +87,7 @@ void scale(std::shared_ptr exec,
     long max_group_size =
         device.get_info();
     int group_size =
-        std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
+        std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
                 max_group_size);

     const dim3 block(group_size);
@@ -141,7 +141,7 @@ void add_scaled(std::shared_ptr exec,
     long max_group_size =
         device.get_info();
     int group_size =
-        std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
+        std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
                 max_group_size);

     const dim3 block(group_size);
@@ -202,7 +202,7 @@ void compute_dot(std::shared_ptr exec,
     long max_group_size =
         device.get_info();
     int group_size =
-        std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
+        std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
                 max_group_size);

     const dim3 block(group_size);
@@ -210,41 +210,37 @@ void compute_dot(std::shared_ptr exec,
     if (x->get_common_size()[1] == 1) {
         exec->get_queue()->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl_nd_range(grid, block),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
-                        auto group = item_ct1.get_group();
-                        auto group_id = group.get_group_linear_id();
-                        const auto x_b =
-                            batch::extract_batch_item(x_ub, group_id);
-                        const auto y_b =
-                            batch::extract_batch_item(y_ub, group_id);
-                        const auto res_b =
-                            batch::extract_batch_item(res_ub, group_id);
-                        single_rhs_compute_dot_sg(x_b.num_rows, x_b.values,
-                                                  y_b.values, res_b.values[0],
-                                                  item_ct1);
-                    });
+                sycl_nd_range(grid, block), [=
+            ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
+                max_subgroup_size)]] {
+                    auto group = item_ct1.get_group();
+                    auto group_id = group.get_group_linear_id();
+                    const auto x_b = batch::extract_batch_item(x_ub, group_id);
+                    const auto y_b = batch::extract_batch_item(y_ub, group_id);
+                    const auto res_b =
+                        batch::extract_batch_item(res_ub, group_id);
+                    single_rhs_compute_conj_dot_sg(x_b.num_rows, x_b.values,
+                                                   y_b.values, res_b.values[0],
+                                                   item_ct1);
+                });
         });
     } else {
         // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group
         exec->get_queue()->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl_nd_range(grid, block),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
-                        auto group = item_ct1.get_group();
-                        auto group_id = group.get_group_linear_id();
-                        const auto x_b =
-                            batch::extract_batch_item(x_ub, group_id);
-                        const auto y_b =
-                            batch::extract_batch_item(y_ub, group_id);
-                        const auto res_b =
-                            batch::extract_batch_item(res_ub, group_id);
-                        compute_gen_dot_product_kernel(
-                            x_b, y_b, res_b, item_ct1,
-                            [](auto val) { return val; });
-                    });
+                sycl_nd_range(grid, block), [=
+            ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
+                max_subgroup_size)]] {
+                    auto group = item_ct1.get_group();
+                    auto group_id = group.get_group_linear_id();
+                    const auto x_b = batch::extract_batch_item(x_ub, group_id);
+                    const auto y_b = batch::extract_batch_item(y_ub, group_id);
+                    const auto res_b =
+                        batch::extract_batch_item(res_ub, group_id);
+                    compute_gen_dot_product_kernel(
+                        x_b, y_b, res_b, item_ct1,
+                        [](auto val) { return val; });
+                });
         });
     }
 }
@@ -270,7 +266,7 @@ void compute_conj_dot(std::shared_ptr exec,
     long max_group_size =
         device.get_info();
     int group_size =
-        std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
+        std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
                 max_group_size);

     const dim3 block(group_size);
@@ -278,19 +274,18 @@ void compute_conj_dot(std::shared_ptr exec,

     exec->get_queue()->submit([&](sycl::handler& cgh) {
         cgh.parallel_for(
-            sycl_nd_range(grid, block),
-            [=](sycl::nd_item<3> item_ct1)
-                [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
-                    auto group = item_ct1.get_group();
-                    auto group_id = group.get_group_linear_id();
-                    const auto x_b = batch::extract_batch_item(x_ub, group_id);
-                    const auto y_b = batch::extract_batch_item(y_ub, group_id);
-                    const auto res_b =
-                        batch::extract_batch_item(res_ub, group_id);
-                    compute_gen_dot_product_kernel(
-                        x_b, y_b, res_b, item_ct1,
-                        [](auto val) { return conj(val); });
-                });
+            sycl_nd_range(grid, block), [=
+        ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
+            max_subgroup_size)]] {
+                auto group = item_ct1.get_group();
+                auto group_id = group.get_group_linear_id();
+                const auto x_b = batch::extract_batch_item(x_ub, group_id);
+                const auto y_b = batch::extract_batch_item(y_ub, group_id);
+                const auto res_b = batch::extract_batch_item(res_ub, group_id);
+                compute_gen_dot_product_kernel(
+                    x_b, y_b, res_b, item_ct1,
+                    [](auto val) { return conj(val); });
+            });
     });
 }
@@ -314,7 +309,7 @@ void compute_norm2(std::shared_ptr exec,
     long max_group_size =
         device.get_info();
     int group_size =
-        std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
+        std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
                 max_group_size);

     const dim3 block(group_size);
@@ -322,33 +317,31 @@ void compute_norm2(std::shared_ptr exec,
     if (x->get_common_size()[1] == 1) {
         exec->get_queue()->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl_nd_range(grid, block),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
-                        auto group = item_ct1.get_group();
-                        auto group_id = group.get_group_linear_id();
-                        const auto x_b =
-                            batch::extract_batch_item(x_ub, group_id);
-                        const auto res_b =
-                            batch::extract_batch_item(res_ub, group_id);
-                        single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
-                                                    res_b.values[0], item_ct1);
-                    });
+                sycl_nd_range(grid, block), [=
+            ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
+                max_subgroup_size)]] {
+                    auto group = item_ct1.get_group();
+                    auto group_id = group.get_group_linear_id();
+                    const auto x_b = batch::extract_batch_item(x_ub, group_id);
+                    const auto res_b =
+                        batch::extract_batch_item(res_ub, group_id);
+                    single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
+                                                res_b.values[0], item_ct1);
+                });
         });
     } else {
         exec->get_queue()->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl_nd_range(grid, block),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
-                        auto group = item_ct1.get_group();
-                        auto group_id = group.get_group_linear_id();
-                        const auto x_b =
-                            batch::extract_batch_item(x_ub, group_id);
-                        const auto res_b =
-                            batch::extract_batch_item(res_ub, group_id);
-                        compute_norm2_kernel(x_b, res_b, item_ct1);
-                    });
+                sycl_nd_range(grid, block), [=
+            ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
+                max_subgroup_size)]] {
+                    auto group = item_ct1.get_group();
+                    auto group_id = group.get_group_linear_id();
+                    const auto x_b = batch::extract_batch_item(x_ub, group_id);
+                    const auto res_b =
+                        batch::extract_batch_item(res_ub, group_id);
+                    compute_norm2_kernel(x_b, res_b, item_ct1);
+                });
         });
     }
 }
@@ -372,7 +365,7 @@ void copy(std::shared_ptr exec,
     long max_group_size =
         device.get_info();
     int group_size =
-        std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
+        std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
                 max_group_size);

     const dim3 block(group_size);
diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc
index 4db1dc5e1d7..1fb5684871d 100644
--- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc
+++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc
@@ -67,15 +67,12 @@ __dpct_inline__ void add_scaled_kernel(
 }


-template
-__dpct_inline__ void single_rhs_compute_dot(
+template
+__dpct_inline__ void single_rhs_compute_conj_dot(
     const int num_rows, const ValueType* const __restrict__ x,
     const ValueType* const __restrict__ y, ValueType& result,
     sycl::nd_item<3> item_ct1)
 {
-    // auto grp =
-    //     group::tiled_partition(group::this_thread_block(item_ct1));
-    // auto grp = group::this_thread_block(item_ct1);
     const auto group = item_ct1.get_group();
     const auto group_size = item_ct1.get_local_range().size();
     const auto tid = item_ct1.get_local_linear_id();
@@ -90,7 +87,7 @@


 template
-__dpct_inline__ void single_rhs_compute_dot_sg(
+__dpct_inline__ void single_rhs_compute_conj_dot_sg(
     const int num_rows, const ValueType* const __restrict__ x,
     const ValueType* const __restrict__ y, ValueType& result,
     sycl::nd_item<3> item_ct1)
@@ -183,8 +180,6 @@ __dpct_inline__ void single_rhs_compute_norm2(
     const int num_rows, const ValueType* const __restrict__ x,
     gko::remove_complex& result, sycl::nd_item<3> item_ct1)
 {
-    // auto grp =
-    //     group::tiled_partition(group::this_thread_block(item_ct1));
     const auto group = item_ct1.get_group();
     const auto group_size = item_ct1.get_local_range().size();
     const auto tid = item_ct1.get_local_linear_id();
@@ -197,8 +192,6 @@ __dpct_inline__ void single_rhs_compute_norm2(
     }

     val = sycl::reduce_over_group(group, val, sycl::plus<>());
-    // val = ::gko::kernels::dpcpp::reduce(
-    //     grp, val, [](real_type a, real_type b) { return a + b; });
     result = sqrt(val);
 }
diff --git a/dpcpp/matrix/batch_struct.hpp b/dpcpp/matrix/batch_struct.hpp
index fe04407d82d..7f36378d8e1 100644
--- a/dpcpp/matrix/batch_struct.hpp
+++ b/dpcpp/matrix/batch_struct.hpp
@@ -91,7 +91,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct(
 * Generates an immutable uniform batch struct from a batch of ell matrices.
 */
 template
-inline batch::matrix::ell::uniform_batch
+inline batch::matrix::ell::uniform_batch
 get_batch_struct(const batch::matrix::Ell* const op)
 {
     return {op->get_const_values(),
diff --git a/dpcpp/preconditioner/batch_identity.hpp.inc b/dpcpp/preconditioner/batch_identity.hpp.inc
index 404d987a3f4..792886f845d 100644
--- a/dpcpp/preconditioner/batch_identity.hpp.inc
+++ b/dpcpp/preconditioner/batch_identity.hpp.inc
@@ -42,15 +42,9 @@ public:
     static int dynamic_work_size(int, int) { return 0; }

-    void generate(
-        size_type batch_id,
-        const gko::batch::matrix::ell::batch_item&,
-        ValueType* const, sycl::nd_item<3> item_ct1)
-    {}
-
-    void generate(size_type batch_id,
-                  const gko::batch::matrix::dense::batch_item&,
-                  ValueType* const, sycl::nd_item<3> item_ct1)
+    template <typename batch_item_type>
+    void generate(size_type, const batch_item_type&, ValueType*,
+                  sycl::nd_item<3> item_ct1)
     {}

     __dpct_inline__ void apply(const int num_rows, const ValueType* const r,
diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
index 33749e91ae4..52f794cfc0e 100644
--- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
+++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp
@@ -60,9 +60,9 @@ namespace gko {
 namespace kernels {
 namespace dpcpp {
 /**
- * @brief The batch Cg solver namespace.
+ * @brief The batch Bicgstab solver namespace.
  *
- * @ingroup batch_cg
+ * @ingroup batch_bicgstab
  */
 namespace batch_bicgstab {
@@ -77,10 +77,10 @@ namespace batch_bicgstab {
 template
 using settings = gko::kernels::batch_bicgstab::settings;


-__dpct_inline__ int get_group_size(int value, int simd_len = 32)
+__dpct_inline__ int get_group_size(int value, int subgroup_size = 32)
 {
-    int num_sg = ceildiv(value, simd_len);
-    return num_sg * simd_len;
+    int num_sg = ceildiv(value, subgroup_size);
+    return num_sg * subgroup_size;
 }
@@ -92,9 +92,9 @@ class KernelCaller {
         : exec_{std::move(exec)}, settings_{settings}
     {}

-    template
+    template
     __dpct_inline__ void launch_apply_kernel(
         const gko::kernels::batch_bicgstab::storage_config& sconf,
         LogType& logger, PrecType& prec, const BatchMatrixType mat,
@@ -111,15 +111,16 @@ class KernelCaller {
         auto max_iters = settings_.max_iterations;
         auto res_tol = settings_.residual_tol;

-        (exec_->get_queue())->submit([&](sycl::handler& cgh) {
+        exec_->get_queue()->submit([&](sycl::handler& cgh) {
             sycl::accessor slm_values(sycl::range<1>(shared_size), cgh);
             cgh.parallel_for(
-                sycl_nd_range(grid, block),
-                [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
-                    simd_len)]] [[intel::kernel_args_restrict]] {
+                sycl_nd_range(grid, block), [=
+            ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
+                  subgroup_size)]] [
+                [intel::kernel_args_restrict]] {
                     auto batch_id = item_ct1.get_group_linear_id();
                     const auto mat_global_entry =
                         gko::batch::matrix::extract_batch_item(mat, batch_id);
@@ -162,10 +163,13 @@ class KernelCaller {
         // reserve 5 for intermediate rho-s, norms,
         // alpha, omega, temp and for reduce_over_group
         // If the value available is negative, then set it to 0
-        size_type shmem_per_blk = std::max(
-            device.get_info() -
-                (group_size + 5) * sizeof(ValueType) - 2 * sizeof(real_type),
-            static_cast(0));
+        const int static_var_mem =
+            (group_size + 5) * sizeof(ValueType) + 2 * sizeof(real_type);
+        int shmem_per_blk = std::max(
+            static_cast(
+                device.get_info()) -
+                static_var_mem,
+            0);
         const int padded_num_rows = num_rows;
         const size_type prec_size = PrecType::dynamic_work_size(
             padded_num_rows, mat.get_single_item_num_nnz());
@@ -185,16 +189,17 @@ class KernelCaller {
         int n_shared_total = sconf.n_shared + int(sconf.prec_shared);

         // template
-        // launch_apply_kernel
-        if (num_rows <= 32 && n_shared_total == 10)
+        // launch_apply_kernel
+        if (num_rows <= 32 && n_shared_total == 10) {
             launch_apply_kernel(
                 sconf, logger, prec, mat, b.values, x.values, workspace_data,
                 group_size, shared_size);
-        else if (num_rows <= 256 && n_shared_total == 10)
+        } else if (num_rows <= 256 && n_shared_total == 10) {
             launch_apply_kernel(
                 sconf, logger, prec, mat, b.values, x.values, workspace_data,
                 group_size, shared_size);
-        else {
+        } else {
             switch (n_shared_total) {
             case 0:
                 launch_apply_kernel(
diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc
index 67057f80e53..0b6f4511f02 100644
--- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc
+++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc
@@ -39,6 +39,7 @@ __dpct_inline__ void initialize(
     ValueType& alpha, ValueType* const x_shared_entry,
     ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
-    ValueType* const p_shared_entry, ValueType* const v_shared_entry,
+    ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry,
+    ValueType* const v_shared_entry,
     typename gko::remove_complex& rhs_norm,
     typename gko::remove_complex& res_norm,
     sycl::nd_item<3> item_ct1)
@@ -85,6 +86,7 @@ __dpct_inline__ void initialize(
     for (int iz = tid; iz < num_rows; iz += group_size) {
         r_hat_shared_entry[iz] = r_shared_entry[iz];
         p_shared_entry[iz] = zero();
+        p_hat_shared_entry[iz] = zero();
         v_shared_entry[iz] = zero();
     }
 }
@@ -115,23 +117,24 @@ __dpct_inline__ void compute_alpha(
     const int num_rows, const ValueType& rho_new,
     const ValueType* const v_shared_entry, ValueType& alpha,
     sycl::nd_item<3> item_ct1)
 {
+    auto sg = item_ct1.get_sub_group();
+    const auto sg_id = sg.get_group_id();
+    const auto tid = item_ct1.get_local_linear_id();
     if constexpr (sg_kernel_all) {
-        auto sg = item_ct1.get_sub_group();
-        const auto sg_id = sg.get_group_id();
-        const auto tid = item_ct1.get_local_linear_id();
-
         if (sg_id == 0) {
-            single_rhs_compute_dot_sg(num_rows, r_hat_shared_entry,
-                                      v_shared_entry, alpha, item_ct1);
+            single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
+                                           v_shared_entry, alpha, item_ct1);
         }
         if (tid == 0) {
             alpha = rho_new / alpha;
         }
         item_ct1.barrier(sycl::access::fence_space::local_space);
     } else {
-        single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry,
-                               alpha, item_ct1);
-        alpha = rho_new / alpha;
+        single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry,
+                                    v_shared_entry, alpha, item_ct1);
+        if (tid == 0) {
+            alpha = rho_new / alpha;
+        }
     }
 }
@@ -158,26 +161,30 @@ __dpct_inline__ void compute_omega(
     const int num_rows, const ValueType* const t_shared_entry,
     const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega,
     sycl::nd_item<3> item_ct1)
 {
+    auto sg = item_ct1.get_sub_group();
+    const auto sg_id = sg.get_group_id();
+    const auto tid = item_ct1.get_local_linear_id();
     if constexpr (sg_kernel_all) {
-        auto sg = item_ct1.get_sub_group();
-        const auto sg_id = sg.get_group_id();
-        const auto tid = item_ct1.get_local_linear_id();
-
-        if (sg_id == 0)
-            single_rhs_compute_dot_sg(num_rows, t_shared_entry,
-                                      s_shared_entry, omega, item_ct1);
-        else if (sg_id == 1)
-            single_rhs_compute_dot_sg(num_rows, t_shared_entry,
-                                      t_shared_entry, temp, item_ct1);
+        if (sg_id == 0) {
+            single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
+                                           s_shared_entry, omega, item_ct1);
+        } else if (sg_id == 1) {
+            single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
+                                           t_shared_entry, temp, item_ct1);
+        }
         item_ct1.barrier(sycl::access::fence_space::local_space);
-        if (tid == 0) omega /= temp;
+        if (tid == 0) {
+            omega /= temp;
+        }
         item_ct1.barrier(sycl::access::fence_space::local_space);
     } else {
-        single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, omega,
-                               item_ct1);
-        single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, temp,
-                               item_ct1);
-        omega /= temp;
+        single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry,
+                                    omega, item_ct1);
+        single_rhs_compute_conj_dot(num_rows, t_shared_entry, t_shared_entry,
+                                    temp, item_ct1);
+        if (tid == 0) {
+            omega /= temp;
+        }
     }
 }
@@ -244,33 +251,21 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
     real_type* norms_rhs_sh;
     real_type* norms_res_sh;

-    if constexpr (sg_kernel_all) {
-        using tile_value_t = ValueType[5];
-        tile_value_t& values =
-            *sycl::ext::oneapi::group_local_memory_for_overwrite(
-                group);
-        using tile_real_t = real_type[2];
-        tile_real_t& reals =
-            *sycl::ext::oneapi::group_local_memory_for_overwrite(
-                group);
-        rho_old_sh = &values[0];
-        rho_new_sh = &values[1];
-        alpha_sh = &values[2];
-        omega_sh = &values[3];
-        temp_sh = &values[4];
-        norms_rhs_sh = &reals[0];
-        norms_res_sh = &reals[1];
-    } else {
-        ValueType values[5];
-        real_type reals[2];
-        rho_old_sh = &values[0];
-        rho_new_sh = &values[1];
-        alpha_sh = &values[2];
-        omega_sh = &values[3];
-        temp_sh = &values[4];
-        norms_rhs_sh = &reals[0];
-        norms_res_sh = &reals[1];
-    }
+    using tile_value_t = ValueType[5];
+    tile_value_t& values =
+        *sycl::ext::oneapi::group_local_memory_for_overwrite(
+            group);
+    using tile_real_t = real_type[2];
+    tile_real_t& reals =
+        *sycl::ext::oneapi::group_local_memory_for_overwrite(
+            group);
+    rho_old_sh = &values[0];
+    rho_new_sh = &values[1];
+    alpha_sh = &values[2];
+    omega_sh = &values[3];
+    temp_sh = &values[4];
+    norms_rhs_sh = &reals[0];
+    norms_res_sh = &reals[1];
     const int gmem_offset =
         batch_id * sconf.gmem_stride_bytes / sizeof(ValueType);
     ValueType* p_hat_sh;
@@ -346,11 +341,12 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
     // compute residual norms
     // r_hat = r
     // p = 0
+    // p_hat = 0
     // v = 0
     initialize(num_rows, mat_global_entry, b_global_entry,
                x_global_entry, rho_old_sh[0], omega_sh[0],
-               alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, v_sh,
-               norms_rhs_sh[0], norms_res_sh[0], item_ct1);
+               alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh,
+               v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1);
     item_ct1.barrier(sycl::access::fence_space::local_space);

     // stopping criterion object
@@ -366,13 +362,13 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
         // rho_new =  < r_hat , r > = (r_hat)' * (r)
         if constexpr (sg_kernel_all) {
             if (sg_id == 0) {
-                single_rhs_compute_dot_sg(num_rows, r_hat_sh, r_sh,
-                                          rho_new_sh[0], item_ct1);
+                single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
+                                               rho_new_sh[0], item_ct1);
             }
             item_ct1.barrier(sycl::access::fence_space::local_space);
         } else {
-            single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0],
-                                   item_ct1);
+            single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh,
+                                        rho_new_sh[0], item_ct1);
         }

         // beta = (rho_new / rho_old)*(alpha / omega)
diff --git a/hip/matrix/batch_struct.hip.hpp b/hip/matrix/batch_struct.hip.hpp
index e35f13f1249..ba75b1b634e 100644
--- a/hip/matrix/batch_struct.hip.hpp
+++ b/hip/matrix/batch_struct.hip.hpp
@@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op)
 * Generates an immutable uniform batch struct from a batch of ell matrices.
 */
 template
-inline batch::matrix::ell::uniform_batch, IndexType>
+inline batch::matrix::ell::uniform_batch,
+                                          const IndexType>
 get_batch_struct(const batch::matrix::Ell* const op)
 {
     return {as_hip_type(op->get_const_values()),
diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp
index f769a4fe6f5..a7c3667c8ef 100644
--- a/hip/solver/batch_bicgstab_kernels.hip.cpp
+++ b/hip/solver/batch_bicgstab_kernels.hip.cpp
@@ -85,21 +85,19 @@ template
 int get_num_threads_per_block(std::shared_ptr exec,
                               const int num_rows)
 {
-    int nwarps = num_rows / 4;
-    if (nwarps < 2) {
-        nwarps = 2;
-    }
-    const int min_block_size = 2 * config::warp_size;
+    int num_warps = std::max(num_rows / 4, 2);
+    constexpr int warp_sz = static_cast(config::warp_size);
+    const int min_block_size = 2 * warp_sz;
     const int device_max_threads =
-        ((std::max(num_rows, min_block_size)) / config::warp_size) *
-        config::warp_size;
+        ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz;
     const int num_regs_used_per_thread = 64;
     int max_regs_blk = 0;
     hipDeviceGetAttribute(&max_regs_blk, hipDeviceAttributeMaxRegistersPerBlock,
                           exec->get_device_id());
     const int max_threads_regs = (max_regs_blk / num_regs_used_per_thread);
-    const int max_threads = std::min(max_threads_regs, device_max_threads);
-    return std::min(nwarps * static_cast(config::warp_size), max_threads);
+    int max_threads = std::min(max_threads_regs, device_max_threads);
+    max_threads = max_threads <= 1024 ? max_threads : 1024;
+    return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
 }
@@ -108,12 +106,12 @@ using settings = gko::kernels::batch_bicgstab::settings;

 template
-class KernelCaller {
+class kernel_caller {
 public:
     using value_type = HipValueType;

-    KernelCaller(std::shared_ptr exec,
-                 const settings> settings)
+    kernel_caller(std::shared_ptr exec,
+                  const settings> settings)
         : exec_{exec}, settings_{settings}
     {}
@@ -147,12 +145,11 @@ class kernel_caller {
         const size_type num_batch_items = mat.num_batch_items;
         constexpr int align_multiple = 8;
         const int padded_num_rows =
-            ((mat.num_rows + align_multiple - 1) / align_multiple) *
-            align_multiple;
+            ceildiv(mat.num_rows, align_multiple) * align_multiple;
         const int shmem_per_blk = exec_->get_max_shared_memory_per_block();
         const int block_size = get_num_threads_per_block(exec_, mat.num_rows);
-        assert(block_size >= 2 * config::warp_size);
+        GKO_ASSERT(block_size >= 2 * config::warp_size);

         const size_t prec_size =
             PrecType::dynamic_work_size(padded_num_rows,
@@ -175,62 +172,64 @@ class kernel_caller {
         // Template parameters launch_apply_kernel
-        if (sconf.prec_shared)
-            launch_apply_kernel(
+        if (sconf.prec_shared) {
+            launch_apply_kernel(
                 sconf, logger, prec, mat, b.values, x.values, workspace_data,
                 block_size, shared_size);
-        else {
+        } else {
             switch (sconf.n_shared) {
             case 0:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 1:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 2:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 3:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 4:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 5:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 6:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 7:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 8:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
             case 9:
-                launch_apply_kernel(
+                launch_apply_kernel(
                     sconf, logger, prec, mat, b.values, x.values,
                     workspace_data, block_size, shared_size);
                 break;
+            default:
+                GKO_NOT_IMPLEMENTED;
             }
         }
@@ -254,7 +253,7 @@ void apply(std::shared_ptr exec,
 {
     using hip_value_type = hip_type;
     auto dispatcher = batch::solver::create_dispatcher(
-        KernelCaller(exec, settings), settings, mat, precon);
+        kernel_caller(exec, settings), settings, mat, precon);
     dispatcher.apply(b, x, logdata);
 }
diff --git a/reference/matrix/batch_struct.hpp b/reference/matrix/batch_struct.hpp
index bb7680d1493..94beff5c2c2 100644
--- a/reference/matrix/batch_struct.hpp
+++ b/reference/matrix/batch_struct.hpp
@@ -95,7 +95,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct(
 * Generates an immutable uniform batch struct from a batch of ell matrices.
 */
 template
-inline batch::matrix::ell::uniform_batch
+inline batch::matrix::ell::uniform_batch
 get_batch_struct(const batch::matrix::Ell* const op)
 {
     return {op->get_const_values(),
diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp
index c47c80e64dc..311fb40e5ef 100644
--- a/reference/test/solver/batch_bicgstab_kernels.cpp
+++ b/reference/test/solver/batch_bicgstab_kernels.cpp
@@ -108,8 +108,8 @@ TYPED_TEST(BatchBicgstab, SolvesStencilSystem)
             this->linear_system);

     for (size_t i = 0; i < this->num_batch_items; i++) {
-        ASSERT_LE(res.res_norm->get_const_values()[i] /
-                      this->linear_system.rhs_norm->get_const_values()[i],
+        ASSERT_LE(res.host_res_norm->get_const_values()[i] /
+                      this->linear_system.host_rhs_norm->get_const_values()[i],
                   this->solver_settings.residual_tol);
     }
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, this->linear_system.exact_sol,
@@ -130,9 +130,10 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsResidual)
     auto iter_array = res.log_data->iter_counts.get_const_data();
     auto res_log_array = res.log_data->res_norms.get_const_data();
     for (size_t i = 0; i < this->num_batch_items; i++) {
-        ASSERT_LE(res_log_array[i] / this->linear_system.rhs_norm->at(i, 0, 0),
-                  this->solver_settings.residual_tol);
-        ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i],
+        ASSERT_LE(
+            res_log_array[i] / this->linear_system.host_rhs_norm->at(i, 0, 0),
+            this->solver_settings.residual_tol);
+        ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i],
                     10 * this->eps);
     }
 }
@@ -186,8 +187,8 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem)
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10);

     for (size_t i = 0; i < num_batch_items; i++) {
-        ASSERT_LE(res.res_norm->get_const_values()[i] /
-                      linear_system.rhs_norm->get_const_values()[i],
+        ASSERT_LE(res.host_res_norm->get_const_values()[i] /
+                      linear_system.host_rhs_norm->get_const_values()[i],
                   tol);
     }
 }
@@ -228,8 +229,8 @@ TYPED_TEST(BatchBicgstab, ApplyLogsResAndIters)
     auto res_norm = logger->get_residual_norm();
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50);
     for (size_t i = 0; i < num_batch_items; i++) {
-        auto rel_res_norm = res.res_norm->get_const_values()[i] /
-                            linear_system.rhs_norm->get_const_values()[i];
+        auto rel_res_norm = res.host_res_norm->get_const_values()[i] /
+                            linear_system.host_rhs_norm->get_const_values()[i];
         ASSERT_LE(iter_counts.get_const_data()[i], max_iters);
         EXPECT_LE(res_norm.get_const_data()[i], tol * 50);
         ASSERT_LE(rel_res_norm, tol * 50);
@@ -266,8 +267,8 @@ TYPED_TEST(BatchBicgstab, CanSolveEllSystem)
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10);

     for (size_t i = 0; i < num_batch_items; i++) {
-        ASSERT_LE(res.res_norm->get_const_values()[i] /
-                      linear_system.rhs_norm->get_const_values()[i],
+        ASSERT_LE(res.host_res_norm->get_const_values()[i] /
+                      linear_system.host_rhs_norm->get_const_values()[i],
                   tol * 10);
     }
 }
@@ -302,6 +303,6 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem)
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50);

     for (size_t i = 0; i < num_batch_items; i++) {
-        ASSERT_LE(res.res_norm->get_const_values()[i], tol * 50);
+        ASSERT_LE(res.host_res_norm->get_const_values()[i], tol * 50);
     }
 }
diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp
index 124dd27640c..ea5e7ec782f 100644
--- a/test/solver/batch_bicgstab_kernels.cpp
+++ b/test/solver/batch_bicgstab_kernels.cpp
@@ -117,8 +117,8 @@ TEST_F(BatchBicgstab, SolvesStencilSystem)
                                            solver_settings, linear_system);

     for (size_t i = 0; i < num_batch_items; i++) {
-        ASSERT_LE(res.res_norm->get_const_values()[i] /
-                      linear_system.rhs_norm->get_const_values()[i],
+        ASSERT_LE(res.host_res_norm->get_const_values()[i] /
+                      linear_system.host_rhs_norm->get_const_values()[i],
                   solver_settings.residual_tol);
     }
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol);
@@ -141,9 +141,9 @@ TEST_F(BatchBicgstab, StencilSystemLoggerLogsResidual)
     auto res_log_array = res.log_data->res_norms.get_const_data();

     for (size_t i = 0; i < num_batch_items; i++) {
-        ASSERT_LE(res_log_array[i] / linear_system.rhs_norm->at(i, 0, 0),
+        ASSERT_LE(res_log_array[i] / linear_system.host_rhs_norm->at(i, 0, 0),
                   solver_settings.residual_tol);
-        ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i],
+        ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i],
                     10 * tol);
     }
 }
@@ -185,8 +185,8 @@ TEST_F(BatchBicgstab, CanSolve3ptStencilSystem)
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10);

     for (size_t i = 0; i < num_batch_items; i++) {
-        auto comp_res_norm = res.res_norm->get_const_values()[i] /
-                             linear_system.rhs_norm->get_const_values()[i];
+        auto comp_res_norm = res.host_res_norm->get_const_values()[i] /
+                             linear_system.host_rhs_norm->get_const_values()[i];
         ASSERT_LE(comp_res_norm, tol);
     }
 }
@@ -215,11 +215,11 @@ TEST_F(BatchBicgstab, CanSolveLargeBatchSizeHpdSystem)
                                   &logger->get_residual_norm());
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50);
     for (size_t i = 0; i < num_batch_items; i++) {
-        auto comp_res_norm = res.res_norm->get_const_values()[i] /
-                             linear_system.rhs_norm->get_const_values()[i];
+        auto comp_res_norm = res.host_res_norm->get_const_values()[i] /
+                             linear_system.host_rhs_norm->get_const_values()[i];
         ASSERT_LE(iter_counts->get_const_data()[i], max_iters);
         EXPECT_LE(res_norm->get_const_data()[i] /
-                      linear_system.rhs_norm->get_const_values()[i],
+                      linear_system.host_rhs_norm->get_const_values()[i],
                   tol);
         EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0});
         ASSERT_LE(comp_res_norm, tol);
@@ -250,11 +250,11 @@ TEST_F(BatchBicgstab, CanSolveLargeMatrixSizeHpdSystem)
                                   &logger->get_residual_norm());
     GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50);
     for (size_t i = 0; i < num_batch_items; i++) {
-        auto comp_res_norm = res.res_norm->get_const_values()[i] /
-                             linear_system.rhs_norm->get_const_values()[i];
+        auto comp_res_norm = res.host_res_norm->get_const_values()[i] /
+                             linear_system.host_rhs_norm->get_const_values()[i];
         ASSERT_LE(iter_counts->get_const_data()[i], max_iters);
         EXPECT_LE(res_norm->get_const_data()[i] /
-                      linear_system.rhs_norm->get_const_values()[i],
+                      linear_system.host_rhs_norm->get_const_values()[i],
                   tol);
         EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0});
         ASSERT_LE(comp_res_norm, tol);
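-- 
Notes for reviewers (illustrative sketches, not part of the patch):

The shared/global workspace split in core/solver/batch_bicgstab_kernels.hpp is easiest to follow in isolation. The sketch below mirrors that logic under stated assumptions: a fixed double value type, an assumed 32-byte alignment, and a plain prec_bytes parameter standing in for Prectype::dynamic_work_size; field names follow the storage_config struct in the patch, everything else is simplified.

// Standalone sketch of the Bicgstab storage configuration (assumptions:
// double value type, 32-byte alignment, prec_bytes given directly).
#include <algorithm>
#include <cstdio>

// ceildiv as used throughout Ginkgo: integer division rounded up, so
// ceildiv(x, a) * a rounds x up to the next multiple of a.
constexpr int ceildiv(int num, int den) { return (num + den - 1) / den; }

struct storage_config {
    bool prec_shared;       // preconditioner workspace in shared memory?
    int n_shared;           // number of work vectors kept in shared memory
    int n_global;           // number of work vectors spilled to global memory
    int gmem_stride_bytes;  // per-batch-item stride of the global workspace
    int padded_vec_len;     // vector length after padding
};

void set_gmem_stride_bytes(storage_config& sconf, int vec_bytes,
                           int prec_bytes)
{
    constexpr int align_bytes = 32;  // assumed alignment, for illustration
    int gmem_stride = sconf.n_global * vec_bytes;
    if (!sconf.prec_shared) {
        gmem_stride += prec_bytes;
    }
    // align each batch item's global chunk, exactly like the ceildiv change
    sconf.gmem_stride_bytes =
        gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0;
}

storage_config compute_shared_storage(int available_shared_mem, int num_rows,
                                      int prec_bytes)
{
    constexpr int num_main_vecs = 9;  // bicgstab keeps 9 work vectors
    const int vec_bytes = num_rows * static_cast<int>(sizeof(double));
    storage_config sconf{false, 0, num_main_vecs, 0, num_rows};
    int rem_shared = available_shared_mem;
    // put as many whole vectors as fit into shared memory
    const int num_vecs_shared =
        std::min(std::max(rem_shared, 0) / vec_bytes, num_main_vecs);
    sconf.n_shared = num_vecs_shared;
    sconf.n_global = num_main_vecs - num_vecs_shared;
    rem_shared -= num_vecs_shared * vec_bytes;
    // promote the preconditioner workspace only if all vectors already fit,
    // matching the reordered rem_shared bookkeeping in the patch
    if (sconf.n_global == 0 && rem_shared >= prec_bytes && prec_bytes > 0) {
        sconf.prec_shared = true;
    }
    set_gmem_stride_bytes(sconf, vec_bytes, prec_bytes);
    return sconf;
}

int main()
{
    // e.g. 48 KiB of shared memory, 1024-row systems, no preconditioner data
    auto sconf = compute_shared_storage(48 * 1024, 1024, 0);
    std::printf("n_shared=%d n_global=%d stride=%dB\n", sconf.n_shared,
                sconf.n_global, sconf.gmem_stride_bytes);
}

The reordering of `rem_shared -= num_vecs_shared * vec_size;` before the early return matters because the remaining budget must already account for the vectors placed in shared memory when deciding whether the preconditioner workspace can be promoted.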
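The single_rhs_compute_dot -> single_rhs_compute_conj_dot renames document that the left vector is conjugated, which only matters once ValueType is complex (e.g. rho_new = (r_hat)' * r). A minimal sequential sketch of the quantity the kernels compute; the real kernels perform the same sum as a warp/subgroup reduction:

// Sequential sketch of the conjugated dot product (illustration only).
#include <complex>
#include <cstdio>

template <typename ValueType>
ValueType conj_dot(int num_rows, const ValueType* x, const ValueType* y)
{
    ValueType result{};
    for (int i = 0; i < num_rows; ++i) {
        // conjugate the left operand: result = x^H * y
        result += std::conj(x[i]) * y[i];
    }
    return result;
}

int main()
{
    using cpx = std::complex<double>;
    const cpx x[] = {{1.0, 1.0}, {2.0, -1.0}};
    const cpx y[] = {{1.0, 0.0}, {0.0, 1.0}};
    const auto d = conj_dot(2, x, y);
    // conj(1+i)*1 + conj(2-i)*i = (1-i) + (-1+2i) = 0 + 1i
    std::printf("(%g, %g)\n", d.real(), d.imag());
}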
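Both launch-configuration fixes (CUDA and HIP get_num_threads_per_block) converge on the same clamp: derive a register-limited and a size-limited thread count, round to whole warps, cap at the 1024-thread block limit, and never return fewer than two warps. A host-only sketch with illustrative constants; the real values come from cudaDeviceGetAttribute/hipDeviceGetAttribute:

// Host-only sketch of the block-size clamp (constants are assumptions).
#include <algorithm>
#include <cstdio>

int get_num_threads_per_block(int num_rows, int max_regs_blk)
{
    constexpr int warp_sz = 32;
    constexpr int num_regs_used = 64;        // assumed per-thread register use
    const int min_block_size = 2 * warp_sz;  // never go below two warps
    const int num_warps = std::max(num_rows / 4, 2);
    const int device_max_threads =
        (std::max(num_rows, min_block_size) / warp_sz) * warp_sz;
    // threads limited by the register file, rounded down to full warps
    const int max_threads_regs =
        ((max_regs_blk / num_regs_used) / warp_sz) * warp_sz;
    int max_threads = std::min(max_threads_regs, device_max_threads);
    max_threads = std::min(max_threads, 1024);  // hardware block-size cap
    // the added std::max is the fix: guarantee at least min_block_size,
    // which the GKO_ASSERT in the HIP kernel caller then checks
    return std::max(std::min(num_warps * warp_sz, max_threads),
                    min_block_size);
}

int main()
{
    // a tiny 8-row system still gets a full 64-thread (two-warp) block
    std::printf("%d\n", get_num_threads_per_block(8, 65536));
}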